]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
progress on [ticket:329]
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 19 Oct 2006 07:02:04 +0000 (07:02 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 19 Oct 2006 07:02:04 +0000 (07:02 +0000)
CHANGES
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/sql_util.py
lib/sqlalchemy/topological.py [moved from lib/sqlalchemy/orm/topological.py with 98% similarity]
test/orm/mapper.py

diff --git a/CHANGES b/CHANGES
index 9975170bd217c228c8c7d54bfaffd2fae03e345b..23e3a78d83effa73aed41278617ff101f560c697 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -36,7 +36,7 @@
     methods, methods that are no longer needed.  slightly more constrained
     useage, greater emphasis on explicitness
     - the "primary_key" attribute of Table and other selectables becomes
-    a setlike ColumnCollection object; is no longer ordered or numerically
+    a setlike ColumnCollection object; is ordered but not numerically
     indexed.  a comparison clause between two pks that are derived from the 
     same underlying tables (i.e. such as two Alias objects) can be generated 
     via table1.primary_key==table2.primary_key
index 2d9a4e845118e4e233060f3badf54d756d6434c1..cea363116bd7edb7d163fc60547a2443a6d477a7 100644 (file)
@@ -19,7 +19,7 @@ from session import Session as create_session
 
 __all__ = ['relation', 'backref', 'eagerload', 'lazyload', 'noload', 'deferred', 'defer', 'undefer',
         'mapper', 'clear_mappers', 'clear_mapper', 'sql', 'class_mapper', 'object_mapper', 'MapperExtension', 'Query', 
-        'cascade_mappers', 'polymorphic_union', 'create_session', 'synonym', 'EXT_PASS'
+        'cascade_mappers', 'polymorphic_union', 'create_session', 'synonym', 'contains_eager', 'EXT_PASS'
         ]
 
 def relation(*args, **kwargs):
@@ -75,6 +75,9 @@ def noload(name):
     into a non-load."""
     return strategies.EagerLazyOption(name, lazy=None)
 
+def contains_eager(key, decorator=None):
+    return strategies.RowDecorateOption(key, decorator=decorator)
+    
 def defer(name):
     """returns a MapperOption that will convert the column property of the given 
     name into a deferred load.  Used with mapper.options()"""
index 9a6b404a0211966e64cf69312473bb824a4e0f18..872164d32558c8f433d7aa99ce7f0699c6a4d0b5 100644 (file)
@@ -72,6 +72,7 @@ class StrategizedProperty(MapperProperty):
             self._all_strategies[cls] = strategy
             return strategy
     def setup(self, querycontext, **kwargs):
+        print "SP SETUP, KEY", self.key, " STRAT IS ",  self._get_context_strategy(querycontext)
         self._get_context_strategy(querycontext).setup_query(querycontext, **kwargs)
     def execute(self, selectcontext, instance, row, identitykey, isnew):
         self._get_context_strategy(selectcontext).process_row(selectcontext, instance, row, identitykey, isnew)
@@ -92,32 +93,50 @@ class OperationContext(object):
         self.attributes = {}
         self.recursion_stack = util.Set()
         for opt in options:
-            opt.process_context(self)
+            self.accept_option(opt)
+    def accept_option(self, opt):
+        pass
 
 class MapperOption(object):
     """describes a modification to an OperationContext."""
-    def process_context(self, context):
+    def process_query_context(self, context):
         pass
-            
-class StrategizedOption(MapperOption):
-    """a MapperOption that affects which LoaderStrategy will be used for an operation
-    by a StrategizedProperty."""
+    def process_selection_context(self, context):
+        pass
+
+class PropertyOption(MapperOption):
+    """a MapperOption that is applied to a property off the mapper
+    or one of its child mappers, identified by a dot-separated key."""
     def __init__(self, key):
         self.key = key
-    def get_strategy_class(self):
-        raise NotImplementedError()
-    def process_context(self, context):
+    def process_query_property(self, context, property):
+        pass
+    def process_selection_property(self, context, property):
+        pass
+    def process_query_context(self, context):
+        self.process_query_property(context, self._get_property(context))
+    def process_selection_context(self, context):
+        self.process_selection_property(context, self._get_property(context))
+    def _get_property(self, context):
         try:
-            key = self.__key
+            prop = self.__prop
         except AttributeError:
             mapper = context.mapper
             for token in self.key.split('.'):
                 prop = mapper.props[token]
                 mapper = getattr(prop, 'mapper', None)
-            self.__key = (LoaderStrategy, prop)
-            key = self.__key
-        context.attributes[key] = self.get_strategy_class()
-                    
+            self.__prop = prop
+        return prop
+
+class StrategizedOption(PropertyOption):
+    """a MapperOption that affects which LoaderStrategy will be used for an operation
+    by a StrategizedProperty."""
+    def process_query_property(self, context, property):
+        print  "HI " + self.key + " " + property.key
+        context.attributes[(LoaderStrategy, property)] = self.get_strategy_class()
+    def get_strategy_class(self):
+        raise NotImplementedError()
+
 
 class LoaderStrategy(object):
     """describes the loading behavior of a StrategizedProperty object.  The LoaderStrategy
index e1491a4a843b9b638b694adb34b96e2b352e878f..87c276368e1b63fe7e5722aeeb54b88f043c6b22 100644 (file)
@@ -13,7 +13,7 @@ import query as querylib
 import session as sessionlib
 import weakref
 
-__all__ = ['Mapper', 'MapperExtension', 'class_mapper', 'object_mapper', 'EXT_PASS', 'SelectionContext']
+__all__ = ['Mapper', 'MapperExtension', 'class_mapper', 'object_mapper', 'EXT_PASS']
 
 # a dictionary mapping classes to their primary mappers
 mapper_registry = weakref.WeakKeyDictionary()
@@ -686,20 +686,22 @@ class Mapper(object):
         return mapper_registry[self.class_key]
 
     def is_assigned(self, instance):
-        """returns True if this mapper handles the given instance.  this is dependent
-        not only on class assignment but the optional "entity_name" parameter as well."""
+        """return True if this mapper handles the given instance.  
+        
+        this is dependent not only on class assignment but the optional "entity_name" parameter as well."""
         return instance.__class__ is self.class_ and getattr(instance, '_entity_name', None) == self.entity_name
 
     def _assign_entity_name(self, instance):
-        """assigns this Mapper's entity name to the given instance.  subsequent Mapper lookups for this
-        instance will return the primary mapper corresponding to this Mapper's class and entity name."""
+        """assign this Mapper's entity name to the given instance.  
+        
+        subsequent Mapper lookups for this instance will return the primary 
+        mapper corresponding to this Mapper's class and entity name."""
         instance._entity_name = self.entity_name
         
     def get_session(self):
-        """returns the contextual session provided by the mapper extension chain
+        """return the contextual session provided by the mapper extension chain, if any.
         
-        raises InvalidRequestError if a session cannot be retrieved from the
-        extension chain
+        raises InvalidRequestError if a session cannot be retrieved from the extension chain
         """
         self.compile()
         s = self.extension.get_session()
@@ -708,52 +710,50 @@ class Mapper(object):
         return s
     
     def has_eager(self):
-        """returns True if one of the properties attached to this Mapper is eager loading"""
+        """return True if one of the properties attached to this Mapper is eager loading"""
         return getattr(self, '_has_eager', False)
-        
     
     def instances(self, cursor, session, *mappers, **kwargs):
-        """given a cursor (ResultProxy) from an SQLEngine, returns a list of object instances
-        corresponding to the rows in the cursor."""
-        self.__log_debug("instances()")
-        self.compile()
+        """return a list of mapped instances corresponding to the rows in a given ResultProxy."""
+        return querylib.Query(self, session).instances(cursor, *mappers, **kwargs)
+
+    def identity_key_from_row(self, row):
+        """return an identity-map key for use in storing/retrieving an item from the identity map.
+
+        row - a sqlalchemy.dbengine.RowProxy instance or other map corresponding result-set
+        column names to their values within a row.
+        """
+        return (self.class_, tuple([row[column] for column in self.pks_by_table[self.mapped_table]]), self.entity_name)
         
-        context = SelectionContext(self, session, **kwargs)
+    def identity_key_from_primary_key(self, primary_key):
+        """return an identity-map key for use in storing/retrieving an item from an identity map.
         
-        result = util.UniqueAppender([])
-        if mappers:
-            otherresults = []
-            for m in mappers:
-                otherresults.append(util.UniqueAppender([]))
-                
-        for row in cursor.fetchall():
-            self._instance(context, row, result)
-            i = 0
-            for m in mappers:
-                m._instance(context, row, otherresults[i])
-                i+=1
-                
-        # store new stuff in the identity map
-        for value in context.identity_map.values():
-            session._register_persistent(value)
-            
-        if mappers:
-            return [result.data] + [o.data for o in otherresults]
-        else:
-            return result.data
+        primary_key - a list of values indicating the identifier.
+        """
+        return (self.class_, tuple(util.to_list(primary_key)), self.entity_name)
+
+    def identity_key_from_instance(self, instance):
+        """return the identity key for the given instance, based on its primary key attributes.
         
-    def identity_key(self, primary_key):
-        """returns the instance key for the given identity value.  this is a global tracking object used by the Session, and is usually available off a mapped object as instance._instance_key."""
-        return sessionlib.get_id_key(util.to_list(primary_key), self.class_, self.entity_name)
+        this value is typically also found on the instance itself under the attribute name '_instance_key'.
+        """
+        return self.identity_key_from_primary_key(self.primary_key_from_instance(instance))
+
+    def primary_key_from_instance(self, instance):
+        """return the list of primary key values for the given instance."""
+        return [self._getattrbycolumn(instance, column) for column in self.pks_by_table[self.mapped_table]]
 
     def instance_key(self, instance):
-        """returns the instance key for the given instance.  this is a global tracking object used by the Session, and is usually available off a mapped object as instance._instance_key."""
-        return self.identity_key(self.identity(instance))
+        """deprecated.  a synonym for identity_key_from_instance."""
+        return self.identity_key_from_instance(instance)
+
+    def identity_key(self, primary_key):
+        """deprecated.  a synonym for identity_key_from_primary_key."""
+        return self.identity_key_from_primary_key(primary_key)
 
     def identity(self, instance):
-        """returns the identity (list of primary key values) for the given instance.  The list of values can be fed directly into the get() method as mapper.get(*key)."""
-        return [self._getattrbycolumn(instance, column) for column in self.pks_by_table[self.mapped_table]]
-        
+        """deprecated. a synoynm for primary_key_from_instance."""
+        return self.primary_key_from_instance(instance)
 
     def _getpropbycolumn(self, column, raiseerror=True):
         try:
@@ -1090,8 +1090,6 @@ class Mapper(object):
         for prop in self.__props.values():
             prop.cascade_callable(type, object, callable_, recursive)
             
-    def _row_identity_key(self, row):
-        return sessionlib.get_row_key(row, self.class_, self.pks_by_table[self.mapped_table], self.entity_name)
 
     def get_select_mapper(self):
         """return the mapper used for issuing selects.
@@ -1117,7 +1115,7 @@ class Mapper(object):
         # been exposed to being modified by the application.
         
         populate_existing = context.populate_existing or self.always_refresh
-        identitykey = self._row_identity_key(row)
+        identitykey = self.identity_key_from_row(row)
         if context.session.has_key(identitykey):
             instance = context.session._get(identitykey)
             self.__log_debug("_instance(): using existing instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey)))
@@ -1202,36 +1200,6 @@ class Mapper(object):
 
 Mapper.logger = logging.class_logger(Mapper)
 
-class SelectionContext(OperationContext):
-    """created within the mapper.instances() method to store and share
-    state among all the Mappers and MapperProperty objects used in a load operation.
-    
-    SelectionContext contains these attributes:
-    
-    mapper - the Mapper which originated the instances() call.
-    
-    session - the Session that is relevant to the instances call.
-    
-    identity_map - a dictionary which stores newly created instances that have
-    not yet been added as persistent to the Session.
-    
-    attributes - a dictionary to store arbitrary data; eager loaders use it to
-    store additional result lists
-    
-    populate_existing - indicates if its OK to overwrite the attributes of instances
-    that were already in the Session
-    
-    version_check - indicates if mappers that have version_id columns should verify
-    that instances existing already within the Session should have this attribute compared
-    to the freshly loaded value
-    
-    """
-    def __init__(self, mapper, session, **kwargs):
-        self.populate_existing = kwargs.pop('populate_existing', False)
-        self.version_check = kwargs.pop('version_check', False)
-        self.session = session
-        self.identity_map = {}
-        super(SelectionContext, self).__init__(mapper, kwargs.pop('with_options', []), **kwargs)
 
 class MapperExtension(object):
     """base implementation for an object that provides overriding behavior to various
@@ -1378,6 +1346,8 @@ def has_identity(object):
 def has_mapper(object):
     """returns True if the given object has a mapper association"""
     return hasattr(object, '_entity_name')
+
+
         
 def object_mapper(object, raiseerror=True):
     """given an object, returns the primary Mapper associated with the object instance"""
index b436adb01452698cdbc694fc1625cfdb3314ea31..a7021d4722506ad21ec84ba3a7ff26f9b64b217d 100644 (file)
@@ -5,11 +5,13 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 import session as sessionlib
-from sqlalchemy import sql, util, exceptions, sql_util
+from sqlalchemy import sql, util, exceptions, sql_util, logging
 
 import mapper
 from interfaces import OperationContext
 
+__all__ = ['Query', 'QueryContext', 'SelectionContext']
+
 class Query(object):
     """encapsulates the object-fetching operations provided by Mappers."""
     def __init__(self, class_or_mapper, session=None, entity_name=None, lockmode=None, with_options=None, **kwargs):
@@ -244,7 +246,7 @@ class Query(object):
 
     def select_text(self, text, **params):
         t = sql.text(text)
-        return self.instances(t, params=params)
+        return self.execute(t, params=params)
 
     def options(self, *args, **kwargs):
         """returns a new Query object using the given MapperOptions."""
@@ -268,12 +270,43 @@ class Query(object):
         else:
             raise AttributeError(key)
 
-    def instances(self, clauseelement, params=None, *args, **kwargs):
+    def execute(self, clauseelement, params=None, *args, **kwargs):
         result = self.session.execute(self.mapper, clauseelement, params=params)
         try:
             return self.mapper.instances(result, self.session, with_options=self.with_options, **kwargs)
         finally:
             result.close()
+
+    def instances(self, cursor, *mappers, **kwargs):
+        """return a list of mapped instances corresponding to the rows in a given ResultProxy."""
+        self.__log_debug("instances()")
+
+        session = self.session
+        
+        context = SelectionContext(self.mapper, session, **kwargs)
+
+        result = util.UniqueAppender([])
+        if mappers:
+            otherresults = []
+            for m in mappers:
+                otherresults.append(util.UniqueAppender([]))
+
+        for row in cursor.fetchall():
+            self.mapper._instance(context, row, result)
+            i = 0
+            for m in mappers:
+                m._instance(context, row, otherresults[i])
+                i+=1
+
+        # store new stuff in the identity map
+        for value in context.identity_map.values():
+            session._register_persistent(value)
+
+        if mappers:
+            return [result.data] + [o.data for o in otherresults]
+        else:
+            return result.data
+
         
     def _get(self, key, ident=None, reload=False, lockmode=None):
         lockmode = lockmode or self.lockmode
@@ -308,7 +341,7 @@ class Query(object):
         statement.use_labels = True
         if params is None:
             params = {}
-        return self.instances(statement, params=params, **kwargs)
+        return self.execute(statement, params=params, **kwargs)
 
     def _should_nest(self, querycontext):
         """return True if the given statement options indicate that we should "nest" the
@@ -397,6 +430,11 @@ class Query(object):
         
         return statement
 
+    def __log_debug(self, msg):
+        self.logger.debug(msg)
+
+Query.logger = logging.class_logger(Query)
+
 class QueryContext(OperationContext):
     """created within the Query.compile() method to store and share
     state among all the Mappers and MapperProperty objects used in a query construction."""
@@ -412,4 +450,40 @@ class QueryContext(OperationContext):
         super(QueryContext, self).__init__(query.mapper, query.with_options, **kwargs)
     def select_args(self):
         return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct}
+    def accept_option(self, opt):
+        opt.process_query_context(self)
+
+
+class SelectionContext(OperationContext):
+    """created within the query.instances() method to store and share
+    state among all the Mappers and MapperProperty objects used in a load operation.
+
+    SelectionContext contains these attributes:
+
+    mapper - the Mapper which originated the instances() call.
+
+    session - the Session that is relevant to the instances call.
+
+    identity_map - a dictionary which stores newly created instances that have
+    not yet been added as persistent to the Session.
+
+    attributes - a dictionary to store arbitrary data; eager loaders use it to
+    store additional result lists
+
+    populate_existing - indicates if its OK to overwrite the attributes of instances
+    that were already in the Session
+
+    version_check - indicates if mappers that have version_id columns should verify
+    that instances existing already within the Session should have this attribute compared
+    to the freshly loaded value
+
+    """
+    def __init__(self, mapper, session, **kwargs):
+        self.populate_existing = kwargs.pop('populate_existing', False)
+        self.version_check = kwargs.pop('version_check', False)
+        self.session = session
+        self.identity_map = {}
+        super(SelectionContext, self).__init__(mapper, kwargs.pop('with_options', []), **kwargs)
+    def accept_option(self, opt):
+        opt.process_selection_context(self)
         
\ No newline at end of file
index 3ec3044b8d689b276b1bdca7ad7fed55119837bd..ca64faebe8e62bddcf2489b9cbce798d115a7c1b 100644 (file)
@@ -140,15 +140,17 @@ class Session(object):
         self.uow.echo = echo
             
     def mapper(self, class_, entity_name=None):
-        """given an Class, returns the primary Mapper responsible for persisting it"""
+        """given an Class, return the primary Mapper responsible for persisting it"""
         return class_mapper(class_, entity_name = entity_name)
     def bind_mapper(self, mapper, bindto):
-        """binds the given Mapper to the given Engine or Connection.  All subsequent operations involving this
-        Mapper will use the given bindto."""
+        """bind the given Mapper to the given Engine or Connection.  
+        
+        All subsequent operations involving this Mapper will use the given bindto."""
         self.binds[mapper] = bindto
     def bind_table(self, table, bindto):
-        """binds the given Table to the given Engine or Connection.  All subsequent operations involving this
-        Table will use the given bindto."""
+        """bind the given Table to the given Engine or Connection.  
+        
+        All subsequent operations involving this Table will use the given bindto."""
         self.binds[table] = bindto
     def get_bind(self, mapper):
         """return the Engine or Connection which is used to execute statements on behalf of the given Mapper.
@@ -198,36 +200,6 @@ class Session(object):
                     
     sql = property(_sql)
     
-        
-    def get_id_key(ident, class_, entity_name=None):
-        """return an identity-map key for use in storing/retrieving an item from the identity map.
-
-        ident - a tuple of primary key values corresponding to the object to be stored.  these
-        values should be in the same order as the primary keys of the table 
-
-        class_ - a reference to the object's class
-
-        entity_name - optional string name to further qualify the class
-        """
-        return (class_, tuple(ident), entity_name)
-    get_id_key = staticmethod(get_id_key)
-
-    def get_row_key(row, class_, primary_key, entity_name=None):
-        """return an identity-map key for use in storing/retrieving an item from the identity map.
-
-        row - a sqlalchemy.dbengine.RowProxy instance or other map corresponding result-set
-        column names to their values within a row.
-
-        class_ - a reference to the object's class
-
-        primary_key - a list of column objects that will target the primary key values
-        in the given row.
-        
-        entity_name - optional string name to further qualify the class
-        """
-        return (class_, tuple([row[column] for column in primary_key]), entity_name)
-    get_row_key = staticmethod(get_row_key)
-    
     def flush(self, objects=None):
         """flush all the object modifications present in this session to the database.  
         
@@ -265,8 +237,12 @@ class Session(object):
             raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % repr(obj))
 
     def expire(self, obj):
-        """invalidate the data in the given object and sets them to refresh themselves
-        the next time they are requested."""
+        """mark the given object as expired.
+        
+        this will add an instrumentation to all mapped attributes on the instance such that when
+        an attribute is next accessed, the session will reload all attributes on the instance
+        from the database.
+        """
         self._validate_persistent(obj)
         def exp():
             if self.query(obj.__class__)._get(obj._instance_key, reload=True) is None:
@@ -274,6 +250,7 @@ class Session(object):
         attribute_manager.trigger_history(obj, exp)
 
     def is_expired(self, obj, unexpire=False):
+        """return True if the given object has been marked as expired."""
         ret = attribute_manager.has_trigger(obj)
         if ret and unexpire:
             attribute_manager.untrigger_history(obj)
@@ -290,8 +267,10 @@ class Session(object):
         
     def save(self, object, entity_name=None):
         """
-        Adds a transient (unsaved) instance to this Session.  This operation cascades the "save_or_update" 
-        method to associated instances if the relation is mapped with cascade="save-update".        
+        Add a transient (unsaved) instance to this Session.  
+        
+        This operation cascades the "save_or_update" method to associated instances if the 
+        relation is mapped with cascade="save-update".        
         
         The 'entity_name' keyword argument will further qualify the specific Mapper used to handle this
         instance.
@@ -300,15 +279,21 @@ class Session(object):
         object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e))
 
     def update(self, object, entity_name=None):
-        """Brings the given detached (saved) instance into this Session.
-        If there is a persistent instance with the same identifier (i.e. a saved instance already associated with this
-        Session), an exception is thrown. 
+        """Bring the given detached (saved) instance into this Session.
+        
+        If there is a persistent instance with the same identifier already associated
+        with this Session, an exception is thrown. 
+
         This operation cascades the "save_or_update" method to associated instances if the relation is mapped 
         with cascade="save-update"."""
         self._update_impl(object, entity_name=entity_name)
         object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e))
 
     def save_or_update(self, object, entity_name=None):
+        """save or update the given object into this Session.
+        
+        The presence of an '_instance_key' attribute on the instance determines whether to
+        save() or update() the instance."""
         self._save_or_update_impl(object, entity_name=entity_name)
         object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e))
     
@@ -320,11 +305,16 @@ class Session(object):
             self._update_impl(object, entity_name=entity_name)
         
     def delete(self, object, entity_name=None):
-        #self.uow.register_deleted(object)
+        """mark the given instance as deleted.
+        
+        the delete operation occurs upon flush()."""
         for c in [object] + list(object_mapper(object).cascade_iterator('delete', object)):
             self.uow.register_deleted(c)
 
     def merge(self, object, entity_name=None):
+        """merge the object into a newly loaded or existing instance from this Session.
+        
+        note: this method is currently not completely implemented."""
         instance = None
         for obj in [object] + list(object_mapper(object).cascade_iterator('merge', object)):
             key = getattr(obj, '_instance_key', None)
@@ -430,12 +420,6 @@ class Session(object):
         """deprecated; a synynom for merge()"""
         return self.merge(*args, **kwargs)
 
-def get_id_key(ident, class_, entity_name=None):
-    return Session.get_id_key(ident, class_, entity_name)
-
-def get_row_key(row, class_, primary_key, entity_name=None):
-    return Session.get_row_key(row, class_, primary_key, entity_name)
-
 def object_mapper(obj):
     return sqlalchemy.orm.object_mapper(obj)
 
@@ -453,6 +437,7 @@ attribute_manager = unitofwork.attribute_manager
 _sessions = weakref.WeakValueDictionary()
 
 def object_session(obj):
+    """return the Session to which the given object is bound, or None if none."""
     hashkey = getattr(obj, '_sa_session_id', None)
     if hashkey is not None:
         return _sessions.get(hashkey)
@@ -460,9 +445,3 @@ def object_session(obj):
 
 unitofwork.object_session = object_session
 
-
-def get_session(obj=None):
-    """deprecated"""
-    if obj is not None:
-        return object_session(obj)
-    raise exceptions.InvalidRequestError("get_session() is deprecated, and does not return the thread-local session anymore. Use the SessionContext.mapper_extension or import sqlalchemy.mod.threadlocal to establish a default thread-local context.")
index 6bbeb65f1c5b65fae12cec2f611efe6d34af4392..88d7f6c52e8056c027b9980d8232da684c041649 100644 (file)
@@ -458,10 +458,24 @@ class EagerLoader(AbstractRelationLoader):
             return
         
         try:
-            clauses = self.clauses_by_lead_mapper[selectcontext.mapper]
-            decorated_row = clauses._decorate_row(row)
+            # decorate the row according to the stored AliasedClauses for this eager load,
+            # or look for a user-defined decorator in the SelectContext (which was set up by the contains_eager() option)
+            if selectcontext.attributes.has_key((EagerLoader, self)):
+                # custom row decoration function, placed in the selectcontext by the 
+                # contains_eager() mapper option
+                decorator = selectcontext.attributes[(EagerLoader, self)]
+                if decorator is None:
+                    decorated_row = row
+                else:
+                    decorated_row = decorator(row)
+                print "OK! ROW IS", decorated_row
+            else:
+                # AliasedClauses, keyed to the lead mapper used in the query
+                clauses = self.clauses_by_lead_mapper[selectcontext.mapper]
+                decorated_row = clauses._decorate_row(row)
+                print "OK! DECORATED ROW IS", decorated_row
             # check for identity key
-            identity_key = self.mapper._row_identity_key(decorated_row)
+            identity_key = self.mapper.identity_key_from_row(decorated_row)
         except KeyError:
             # else degrade to a lazy loader
             self.logger.debug("degrade to lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
@@ -513,5 +527,11 @@ class EagerLazyOption(StrategizedOption):
         elif self.lazy is None:
             return NoLoader
 
-
+class RowDecorateOption(PropertyOption):
+    def __init__(self, key, decorator=None):
+        super(RowDecorateOption, self).__init__(key)
+        self.decorator = decorator
+    def process_selection_property(self, context, property):
+        context.attributes[(EagerLoader, property)] = self.decorator
+        
 
index 5c2e21c5fa2468c18bcea3fe88fb885f6943a73f..c4fd92e36136a333edd548b238f9131e0053e37e 100644 (file)
@@ -14,12 +14,11 @@ an "identity map" pattern.  The Unit of Work then maintains lists of objects tha
 dirty, or deleted and provides the capability to flush all those changes at once.
 """
 
-from sqlalchemy import attributes, util, logging
+from sqlalchemy import attributes, util, logging, topological
 import sqlalchemy
 from sqlalchemy.exceptions import *
 import StringIO
 import weakref
-import topological
 import sets
 
 class UOWEventHandler(attributes.AttributeExtension):
@@ -203,8 +202,6 @@ class UOWTransaction(object):
         self.mappers = util.Set()
         self.dependencies = {}
         self.tasks = {}
-        self.__modified = False
-        self.__is_executing = False
         self.logger = logging.instance_logger(self)
         self.echo = uow.echo
         
@@ -229,8 +226,7 @@ class UOWTransaction(object):
         self.mappers.add(mapper)
         task = self.get_task_by_mapper(mapper)
         if postupdate:
-            mod = task.append_postupdate(obj, post_update_cols)
-            if mod: self._mark_modified()
+            task.append_postupdate(obj, post_update_cols)
             return
                 
         # for a cyclical task, things need to be sorted out already,
@@ -239,8 +235,7 @@ class UOWTransaction(object):
         if task.circular:
             return
         
-        mod = task.append(obj, listonly, isdelete=isdelete, **kwargs)
-        if mod: self._mark_modified()
+        task.append(obj, listonly, isdelete=isdelete, **kwargs)
 
     def unregister_object(self, obj):
         #print "UNREGISTER", obj
@@ -248,12 +243,6 @@ class UOWTransaction(object):
         task = self.get_task_by_mapper(mapper)
         if obj in task.objects:
             task.delete(obj)
-            self._mark_modified()
-    
-    def _mark_modified(self):
-        #if self.__is_executing:
-        #    raise "test assertion failed"
-        self.__modified = True
     
         
     def is_deleted(self, obj):
@@ -287,7 +276,6 @@ class UOWTransaction(object):
         dependency = dependency.primary_mapper().base_mapper()
         
         self.dependencies[(mapper, dependency)] = True
-        self._mark_modified()
 
     def register_processor(self, mapper, processor, mapperfrom):
         """called by mapper.PropertyLoader to register itself as a "processor", which
@@ -307,7 +295,6 @@ class UOWTransaction(object):
         targettask = self.get_task_by_mapper(mapperfrom)
         up = UOWDependencyProcessor(processor, targettask)
         task.dependencies.add(up)
-        self._mark_modified()
 
     def execute(self):
         # insure that we have a UOWTask for every mapper that will be involved 
@@ -328,14 +315,7 @@ class UOWTransaction(object):
             if not ret:
                 break
         
-        # flip the execution flag on.  in some test cases
-        # we like to check this flag against any new objects being added, since everything
-        # should be registered by now.  there is a slight exception in the case of 
-        # post_update requests; this should be fixed.
-        self.__is_executing = True
-        
         head = self._sort_dependencies()
-        self.__modified = False
         if self.echo:
             if head is None:
                 self.logger.info("Task dump: None")
@@ -343,8 +323,6 @@ class UOWTransaction(object):
                 self.logger.info("Task dump:\n" + head.dump())
         if head is not None:
             head.execute(self)
-        #if self.__modified and head is not None:
-        #    raise "Assertion failed ! new pre-execute dependency step should eliminate post-execute changes (except post_update stuff)."
         self.logger.info("Execute Complete")
             
     def post_exec(self):
@@ -391,182 +369,25 @@ class UOWTransaction(object):
             mappers.add(base)
         return mappers
         
-        
-class UOWTaskElement(object):
-    """an element within a UOWTask.  corresponds to a single object instance
-    to be saved, deleted, or just part of the transaction as a placeholder for 
-    further dependencies (i.e. 'listonly').
-    in the case of self-referential mappers, may also store a list of childtasks,
-    further UOWTasks containing objects dependent on this element's object instance."""
-    def __init__(self, obj):
-        self.obj = obj
-        self.__listonly = True
-        self.childtasks = []
-        self.__isdelete = False
-        self.__preprocessed = {}
-    def _get_listonly(self):
-        return self.__listonly
-    def _set_listonly(self, value):
-        """set_listonly is a one-way setter, will only go from True to False."""
-        if not value and self.__listonly:
-            self.__listonly = False
-            self.clear_preprocessed()
-    def _get_isdelete(self):
-        return self.__isdelete
-    def _set_isdelete(self, value):
-        if self.__isdelete is not value:
-            self.__isdelete = value
-            self.clear_preprocessed()
-    listonly = property(_get_listonly, _set_listonly)
-    isdelete = property(_get_isdelete, _set_isdelete)
-    
-    def mark_preprocessed(self, processor):
-        """marks this element as "preprocessed" by a particular UOWDependencyProcessor.  preprocessing is the step
-        which sweeps through all the relationships on all the objects in the flush transaction and adds other objects
-        which are also affected,  In some cases it can switch an object from "tosave" to "todelete".  changes to the state 
-        of this UOWTaskElement will reset all "preprocessed" flags, causing it to be preprocessed again.  When all UOWTaskElements
-        have been fully preprocessed by all UOWDependencyProcessors, then the topological sort can be done."""
-        self.__preprocessed[processor] = True
-    def is_preprocessed(self, processor):
-        return self.__preprocessed.get(processor, False)
-    def clear_preprocessed(self):
-        self.__preprocessed.clear()
-    def __repr__(self):
-        return "UOWTaskElement/%d: %s/%d %s" % (id(self), self.obj.__class__.__name__, id(self.obj), (self.listonly and 'listonly' or (self.isdelete and 'delete' or 'save')) )
-
-class UOWDependencyProcessor(object):
-    """in between the saving and deleting of objects, process "dependent" data, such as filling in 
-    a foreign key on a child item from a new primary key, or deleting association rows before a 
-    delete.  This object acts as a proxy to a DependencyProcessor."""
-    def __init__(self, processor, targettask):
-        self.processor = processor
-        self.targettask = targettask
-    def __eq__(self, other):
-        return other.processor is self.processor and other.targettask is self.targettask
-    def __hash__(self):
-        return hash((self.processor, self.targettask))
-        
-    def preexecute(self, trans):
-        """traverses all objects handled by this dependency processor and locates additional objects which should be 
-        part of the transaction, such as those affected deletes, orphans to be deleted, etc. Returns True if any
-        objects were preprocessed, or False if no objects were preprocessed."""
-        def getobj(elem):
-            elem.mark_preprocessed(self)
-            return elem.obj
-        
-        ret = False
-        elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if elem.obj is not None and not elem.is_preprocessed(self)]
-        if len(elements):
-            ret = True
-            self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=False)
-
-        elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if elem.obj is not None and not elem.is_preprocessed(self)]
-        if len(elements):
-            ret = True
-            self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True)
-        return ret
-        
-    def execute(self, trans, delete):
-        if not delete:
-            self.processor.process_dependencies(self.targettask, [elem.obj for elem in self.targettask.polymorphic_tosave_elements if elem.obj is not None], trans, delete=False)
-        else:            
-            self.processor.process_dependencies(self.targettask, [elem.obj for elem in self.targettask.polymorphic_todelete_elements if elem.obj is not None], trans, delete=True)
-
-    def get_object_dependencies(self, obj, trans, passive):
-        return self.processor.get_object_dependencies(obj, trans, passive=passive)
-
-    def whose_dependent_on_who(self, obj, o):        
-        return self.processor.whose_dependent_on_who(obj, o)
-
-    def branch(self, task):
-        return UOWDependencyProcessor(self.processor, task)
-
-class UOWExecutor(object):
-    def execute(self, trans, task, isdelete=None):
-        if isdelete is not True:
-            self.execute_save_steps(trans, task)
-        if isdelete is not False:
-            self.execute_delete_steps(trans, task)
-
-    def save_objects(self, trans, task):
-        task._save_objects(trans)
-
-    def delete_objects(self, trans, task):
-        task._delete_objects(trans)
-    
-    def execute_dependency(self, trans, dep, isdelete):
-        dep.execute(trans, isdelete)
-        
-    def execute_save_steps(self, trans, task):
-        if task.circular is not None:
-            self.execute_save_steps(trans, task.circular)
-        else:
-            self.save_objects(trans, task)
-            self.execute_cyclical_dependencies(trans, task, False)
-            self.execute_per_element_childtasks(trans, task, False)
-            self.execute_dependencies(trans, task, False)
-            self.execute_dependencies(trans, task, True)
-            self.execute_childtasks(trans, task, False)
-                
-    def execute_delete_steps(self, trans, task):    
-        if task.circular is not None:
-            self.execute_delete_steps(trans, task.circular)
-        else:
-            self.execute_cyclical_dependencies(trans, task, True)
-            self.execute_childtasks(trans, task, True)
-            self.execute_per_element_childtasks(trans, task, True)
-            self.delete_objects(trans, task)
-
-    def execute_dependencies(self, trans, task, isdelete=None):
-        alltasks = list(task.polymorphic_tasks())
-        if isdelete is not True:
-            for task in alltasks:
-                for dep in task.dependencies:
-                    self.execute_dependency(trans, dep, False)
-        if isdelete is not False:
-            alltasks.reverse()
-            for task in alltasks:
-                for dep in task.dependencies:
-                    self.execute_dependency(trans, dep, True)
-
-    def execute_childtasks(self, trans, task, isdelete=None):
-        for polytask in task.polymorphic_tasks():
-            for child in polytask.childtasks:
-                self.execute(trans, child, isdelete)
-                
-    def execute_cyclical_dependencies(self, trans, task, isdelete):
-        for polytask in task.polymorphic_tasks():
-            for dep in polytask.cyclical_dependencies:
-                self.execute_dependency(trans, dep, isdelete)
-                
-    def execute_per_element_childtasks(self, trans, task, isdelete):
-        for polytask in task.polymorphic_tasks():
-            for element in polytask.tosave_elements + polytask.todelete_elements:
-                self.execute_element_childtasks(trans, element, isdelete)
-    
-    def execute_element_childtasks(self, trans, element, isdelete):
-        for child in element.childtasks:
-            self.execute(trans, child, isdelete)
-    
 class UOWTask(object):
     """represents the full list of objects that are to be saved/deleted by a specific Mapper."""
     def __init__(self, uowtransaction, mapper, circular_parent=None):
         if not circular_parent:
             uowtransaction.tasks[mapper] = self
-        
+
         # the transaction owning this UOWTask
         self.uowtransaction = uowtransaction
-        
+
         # the Mapper which this UOWTask corresponds to
         self.mapper = mapper
-        
+
         # a dictionary mapping object instances to a corresponding UOWTaskElement.
         # Each UOWTaskElement represents one instance which is to be saved or 
         # deleted by this UOWTask's Mapper.
         # in the case of the row-based "circular sort", the UOWTaskElement may
         # also reference further UOWTasks which are dependent on that UOWTaskElement.
         self.objects = {} #util.OrderedDict()
-        
+
         # a list of UOWDependencyProcessors which are executed after saves and
         # before deletes, to synchronize data to dependent objects
         self.dependencies = util.Set()
@@ -575,7 +396,7 @@ class UOWTask(object):
         # are to be executed after this UOWTask performs saves and post-save
         # dependency processing, and before pre-delete processing and deletes
         self.childtasks = []
-        
+
         # whether this UOWTask is circular, meaning it holds a second
         # UOWTask that contains a special row-based dependency structure.
         self.circular = None
@@ -583,16 +404,16 @@ class UOWTask(object):
         # for a task thats part of that row-based dependency structure, points
         # back to the "public facing" task.
         self.circular_parent = circular_parent
-        
+
         # a list of UOWDependencyProcessors are derived from the main
         # set of dependencies, referencing sub-UOWTasks attached to this
         # one which represent portions of the total list of objects.
         # this is used for the row-based "circular sort"
         self.cyclical_dependencies = util.Set()
-        
+
     def is_empty(self):
         return len(self.objects) == 0 and len(self.dependencies) == 0 and len(self.childtasks) == 0
-    
+
     def append(self, obj, listonly = False, childtask = None, isdelete = False):
         """appends an object to this task, to be either saved or deleted depending on the
         'isdelete' attribute of this UOWTask.  'listonly' indicates that the object should
@@ -614,14 +435,14 @@ class UOWTask(object):
         if isdelete:
             rec.isdelete = True
         return retval
-    
+
     def append_postupdate(self, obj, post_update_cols):
         # postupdates are UPDATED immeditely (for now)
         # convert post_update_cols list to a Set so that __hashcode__ is used to compare columns
         # instead of __eq__
         self.mapper.save_obj([obj], self.uowtransaction, postupdate=True, post_update_cols=util.Set(post_update_cols))
         return True
-            
+
     def delete(self, obj):
         try:
             del self.objects[obj]
@@ -633,11 +454,11 @@ class UOWTask(object):
     def _delete_objects(self, trans):
         for task in self.polymorphic_tasks():
             task.mapper.delete_obj(task.todelete_objects, trans)
-            
+
     def execute(self, trans):
         """executes this UOWTask.  saves objects to be saved, processes all dependencies
         that have been registered, and deletes objects to be deleted. """
-        
+
         UOWExecutor().execute(trans, self)
 
     def polymorphic_tasks(self):
@@ -645,10 +466,10 @@ class UOWTask(object):
         mappers are inheriting descendants of this UOWTask's mapper.  UOWTasks are returned in order
         of their hierarchy to each other, meaning if UOWTask B's mapper inherits from UOWTask A's 
         mapper, then UOWTask B will appear after UOWTask A in the iteration."""
-        
+
         # first us
         yield self
-        
+
         # "circular dependency" tasks aren't polymorphic
         if self.circular_parent is not None:
             return
@@ -662,12 +483,12 @@ class UOWTask(object):
                 else:
                     for t in _tasks_by_mapper(m):
                         yield t
-        
+
         # main yield loop
         for task in _tasks_by_mapper(self.mapper):
             for t in task.polymorphic_tasks():
                 yield t
-                
+
     def contains_object(self, obj, polymorphic=False):
         if polymorphic:
             for task in self.polymorphic_tasks():
@@ -680,13 +501,13 @@ class UOWTask(object):
 
     def is_inserted(self, obj):
         return not hasattr(obj, '_instance_key')
-    
+
     def is_deleted(self, obj):
         try:
             return self.objects[obj].isdelete
         except KeyError:
             return False
-          
+
     def get_elements(self, polymorphic=False):
         if polymorphic:
             for task in self.polymorphic_tasks():
@@ -695,7 +516,7 @@ class UOWTask(object):
         else:
             for rec in self.objects.values():
                 yield rec
-    
+
     polymorphic_tosave_elements = property(lambda self: [rec for rec in self.get_elements(polymorphic=True) if not rec.isdelete])
     polymorphic_todelete_elements = property(lambda self: [rec for rec in self.get_elements(polymorphic=True) if rec.isdelete])
     tosave_elements = property(lambda self: [rec for rec in self.get_elements(polymorphic=False) if not rec.isdelete])
@@ -703,30 +524,30 @@ class UOWTask(object):
     tosave_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=False) if rec.obj is not None and not rec.listonly and rec.isdelete is False])
     todelete_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=False) if rec.obj is not None and not rec.listonly and rec.isdelete is True])
     polymorphic_tosave_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=True) if rec.obj is not None and not rec.listonly and rec.isdelete is False])
-        
+
     def _sort_circular_dependencies(self, trans, cycles):
         """for a single task, creates a hierarchical tree of "subtasks" which associate
         specific dependency actions with individual objects.  This is used for a
         "cyclical" task, or a task where elements
         of its object list contain dependencies on each other.
-        
+
         this is not the normal case; this logic only kicks in when something like 
         a hierarchical tree is being represented."""
         allobjects = []
         for task in cycles:
             allobjects += [e.obj for e in task.get_elements(polymorphic=True)]
         tuples = []
-        
+
         cycles = util.Set(cycles)
-        
+
         #print "BEGIN CIRC SORT-------"
         #print "PRE-CIRC:"
         #print list(cycles) #[0].dump()
-        
+
         # dependency processors that arent part of the cyclical thing
         # get put here
         extradeplist = []
-        
+
         # organizes a set of new UOWTasks that will be assembled into
         # the final tree, for the purposes of holding new UOWDependencyProcessors
         # which process small sub-sections of dependent parent/child operations
@@ -748,7 +569,7 @@ class UOWTask(object):
             proctask = trans.get_task_by_mapper(dep.processor.mapper.primary_mapper().base_mapper(), True)
             targettask = trans.get_task_by_mapper(dep.targettask.mapper.base_mapper(), True)
             return targettask in cycles and (proctask is not None and proctask in cycles)
-            
+
         # organize all original UOWDependencyProcessors by their target task
         deps_by_targettask = {}
         for t in cycles:
@@ -761,14 +582,14 @@ class UOWTask(object):
                         l.append(dep)
 
         object_to_original_task = {}
-        
+
         for t in cycles:
             for task in t.polymorphic_tasks():
                 for taskelement in task.get_elements(polymorphic=False):
                     obj = taskelement.obj
                     object_to_original_task[obj] = task
                     #print "OBJ", repr(obj), "TASK", repr(task)
-                    
+
                     for dep in deps_by_targettask.get(task, []):
                         # is this dependency involved in one of the cycles ?
                         #print "DEP iterate", dep.processor.key, dep.processor.parent, dep.processor.mapper
@@ -778,16 +599,16 @@ class UOWTask(object):
                         #print "DEP", dep.processor.key    
                         (processor, targettask) = (dep.processor, dep.targettask)
                         isdelete = taskelement.isdelete
-                    
+
                         # list of dependent objects from this object
                         childlist = dep.get_object_dependencies(obj, trans, passive=True)
                         if childlist is None:
                             continue
                         # the task corresponding to saving/deleting of those dependent objects
                         childtask = trans.get_task_by_mapper(processor.mapper.primary_mapper())
-                    
+
                         childlist = childlist.added_items() + childlist.unchanged_items() + childlist.deleted_items()
-                        
+
                         for o in childlist:
                             if o is None or not childtask.contains_object(o, polymorphic=True):
                                 continue
@@ -804,7 +625,7 @@ class UOWTask(object):
                                     get_dependency_task(whosdep[0], dep).append(whosdep[1], isdelete=isdelete)
                             else:
                                 get_dependency_task(obj, dep).append(obj, isdelete=isdelete)
-        
+
         #print "TUPLES", tuples
         head = DependencySorter(tuples, allobjects).sort()
         if head is None:
@@ -823,7 +644,7 @@ class UOWTask(object):
                 nexttasks[originating_task] = t
                 parenttask.append(None, listonly=False, isdelete=originating_task.objects[node.item].isdelete, childtask=t)
             t.append(node.item, originating_task.objects[node.item].listonly, isdelete=originating_task.objects[node.item].isdelete)
-                
+
             if dependencies.has_key(node.item):
                 for depprocessor, deptask in dependencies[node.item].iteritems():
                     t.cyclical_dependencies.add(depprocessor.branch(deptask))
@@ -848,8 +669,8 @@ class UOWTask(object):
         import uowdumper
         uowdumper.UOWDumper(self, buf)
         return buf.getvalue()
-        
-        
+
+
     def __repr__(self):
         if self.mapper is not None:
             if self.mapper.__class__.__name__ == 'Mapper':
@@ -859,6 +680,164 @@ class UOWTask(object):
         else:
             name = '(none)'
         return ("UOWTask(%d) Mapper: '%s'" % (id(self), name))
+        
+class UOWTaskElement(object):
+    """an element within a UOWTask.  corresponds to a single object instance
+    to be saved, deleted, or just part of the transaction as a placeholder for 
+    further dependencies (i.e. 'listonly').
+    in the case of self-referential mappers, may also store a list of childtasks,
+    further UOWTasks containing objects dependent on this element's object instance."""
+    def __init__(self, obj):
+        self.obj = obj
+        self.__listonly = True
+        self.childtasks = []
+        self.__isdelete = False
+        self.__preprocessed = {}
+    def _get_listonly(self):
+        return self.__listonly
+    def _set_listonly(self, value):
+        """set_listonly is a one-way setter, will only go from True to False."""
+        if not value and self.__listonly:
+            self.__listonly = False
+            self.clear_preprocessed()
+    def _get_isdelete(self):
+        return self.__isdelete
+    def _set_isdelete(self, value):
+        if self.__isdelete is not value:
+            self.__isdelete = value
+            self.clear_preprocessed()
+    listonly = property(_get_listonly, _set_listonly)
+    isdelete = property(_get_isdelete, _set_isdelete)
+    
+    def mark_preprocessed(self, processor):
+        """marks this element as "preprocessed" by a particular UOWDependencyProcessor.  preprocessing is the step
+        which sweeps through all the relationships on all the objects in the flush transaction and adds other objects
+        which are also affected,  In some cases it can switch an object from "tosave" to "todelete".  changes to the state 
+        of this UOWTaskElement will reset all "preprocessed" flags, causing it to be preprocessed again.  When all UOWTaskElements
+        have been fully preprocessed by all UOWDependencyProcessors, then the topological sort can be done."""
+        self.__preprocessed[processor] = True
+    def is_preprocessed(self, processor):
+        return self.__preprocessed.get(processor, False)
+    def clear_preprocessed(self):
+        self.__preprocessed.clear()
+    def __repr__(self):
+        return "UOWTaskElement/%d: %s/%d %s" % (id(self), self.obj.__class__.__name__, id(self.obj), (self.listonly and 'listonly' or (self.isdelete and 'delete' or 'save')) )
+
+class UOWDependencyProcessor(object):
+    """in between the saving and deleting of objects, process "dependent" data, such as filling in 
+    a foreign key on a child item from a new primary key, or deleting association rows before a 
+    delete.  This object acts as a proxy to a DependencyProcessor."""
+    def __init__(self, processor, targettask):
+        self.processor = processor
+        self.targettask = targettask
+    def __eq__(self, other):
+        return other.processor is self.processor and other.targettask is self.targettask
+    def __hash__(self):
+        return hash((self.processor, self.targettask))
+        
+    def preexecute(self, trans):
+        """traverses all objects handled by this dependency processor and locates additional objects which should be 
+        part of the transaction, such as those affected deletes, orphans to be deleted, etc. Returns True if any
+        objects were preprocessed, or False if no objects were preprocessed."""
+        def getobj(elem):
+            elem.mark_preprocessed(self)
+            return elem.obj
+        
+        ret = False
+        elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if elem.obj is not None and not elem.is_preprocessed(self)]
+        if len(elements):
+            ret = True
+            self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=False)
+
+        elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if elem.obj is not None and not elem.is_preprocessed(self)]
+        if len(elements):
+            ret = True
+            self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True)
+        return ret
+        
+    def execute(self, trans, delete):
+        if not delete:
+            self.processor.process_dependencies(self.targettask, [elem.obj for elem in self.targettask.polymorphic_tosave_elements if elem.obj is not None], trans, delete=False)
+        else:            
+            self.processor.process_dependencies(self.targettask, [elem.obj for elem in self.targettask.polymorphic_todelete_elements if elem.obj is not None], trans, delete=True)
+
+    def get_object_dependencies(self, obj, trans, passive):
+        return self.processor.get_object_dependencies(obj, trans, passive=passive)
+
+    def whose_dependent_on_who(self, obj, o):        
+        return self.processor.whose_dependent_on_who(obj, o)
+
+    def branch(self, task):
+        return UOWDependencyProcessor(self.processor, task)
+
+class UOWExecutor(object):
+    """encapsulates the execution traversal of a UOWTransaction structure."""
+    def execute(self, trans, task, isdelete=None):
+        if isdelete is not True:
+            self.execute_save_steps(trans, task)
+        if isdelete is not False:
+            self.execute_delete_steps(trans, task)
+
+    def save_objects(self, trans, task):
+        task._save_objects(trans)
+
+    def delete_objects(self, trans, task):
+        task._delete_objects(trans)
+    
+    def execute_dependency(self, trans, dep, isdelete):
+        dep.execute(trans, isdelete)
+        
+    def execute_save_steps(self, trans, task):
+        if task.circular is not None:
+            self.execute_save_steps(trans, task.circular)
+        else:
+            self.save_objects(trans, task)
+            self.execute_cyclical_dependencies(trans, task, False)
+            self.execute_per_element_childtasks(trans, task, False)
+            self.execute_dependencies(trans, task, False)
+            self.execute_dependencies(trans, task, True)
+            self.execute_childtasks(trans, task, False)
+                
+    def execute_delete_steps(self, trans, task):    
+        if task.circular is not None:
+            self.execute_delete_steps(trans, task.circular)
+        else:
+            self.execute_cyclical_dependencies(trans, task, True)
+            self.execute_childtasks(trans, task, True)
+            self.execute_per_element_childtasks(trans, task, True)
+            self.delete_objects(trans, task)
+
+    def execute_dependencies(self, trans, task, isdelete=None):
+        alltasks = list(task.polymorphic_tasks())
+        if isdelete is not True:
+            for task in alltasks:
+                for dep in task.dependencies:
+                    self.execute_dependency(trans, dep, False)
+        if isdelete is not False:
+            alltasks.reverse()
+            for task in alltasks:
+                for dep in task.dependencies:
+                    self.execute_dependency(trans, dep, True)
+
+    def execute_childtasks(self, trans, task, isdelete=None):
+        for polytask in task.polymorphic_tasks():
+            for child in polytask.childtasks:
+                self.execute(trans, child, isdelete)
+                
+    def execute_cyclical_dependencies(self, trans, task, isdelete):
+        for polytask in task.polymorphic_tasks():
+            for dep in polytask.cyclical_dependencies:
+                self.execute_dependency(trans, dep, isdelete)
+                
+    def execute_per_element_childtasks(self, trans, task, isdelete):
+        for polytask in task.polymorphic_tasks():
+            for element in polytask.tosave_elements + polytask.todelete_elements:
+                self.execute_element_childtasks(trans, element, isdelete)
+    
+    def execute_element_childtasks(self, trans, element, isdelete):
+        for child in element.childtasks:
+            self.execute(trans, child, isdelete)
+    
 
 class DependencySorter(topological.QueueDependencySorter):
     pass
index 4935b1adda0812aecf5f9e673b9174cf6a356b16..bfbcff5541331b726c9f58ae50a91406301f7083 100644 (file)
@@ -1,6 +1,4 @@
-import sqlalchemy.sql as sql
-import sqlalchemy.schema as schema
-import sqlalchemy.util as util
+from sqlalchemy import sql, util, schema, topological
 
 """utility functions that build upon SQL and Schema constructs"""
 
@@ -36,7 +34,6 @@ class TableCollection(object):
             return sorted
             
     def _do_sort(self):
-        import sqlalchemy.orm.topological
         tuples = []
         class TVisitor(schema.SchemaVisitor):
             def visit_foreign_key(_self, fkey):
@@ -49,7 +46,7 @@ class TableCollection(object):
         vis = TVisitor()        
         for table in self.tables:
             table.accept_schema_visitor(vis)
-        sorter = sqlalchemy.orm.topological.QueueDependencySorter( tuples, self.tables )
+        sorter = topological.QueueDependencySorter( tuples, self.tables )
         head =  sorter.sort()
         sequence = []
         def to_sequence( node, seq=sequence):
similarity index 98%
rename from lib/sqlalchemy/orm/topological.py
rename to lib/sqlalchemy/topological.py
index 8c481b6f11645951ec113dd24800dd8df6e04781..948b3cfeadd117050837303c94bcda7f74bab057 100644 (file)
@@ -36,7 +36,7 @@ import sqlalchemy.util as util
 from sqlalchemy.exceptions import *
 
 class QueueDependencySorter(object):
-    """this is a topological sort from wikipedia.  its very stable.  it creates a straight-line
+    """topological sort adapted from wikipedia's article on the subject.  it creates a straight-line
     list of elements, then a second pass groups non-dependent actions together to build
     more of a tree structure with siblings."""
     class Node:
index 6a26c8b4be97d0dbb63018b14448b57dc391ef9a..a048adbd8788d3b165b945f1ddc9ee21c4da8e87 100644 (file)
@@ -477,6 +477,7 @@ class MapperTest(MapperSuperTest):
             print u[0].orders[1].items[0].keywords[1]
         self.assert_sql_count(db, go, 3)
         sess.clear()
+        print "MARK"
         u = q2.select()
         self.assert_sql_count(db, go, 2)
         
@@ -873,6 +874,19 @@ class EagerTest(MapperSuperTest):
             {'user_id' : 9, 'addresses' : (Address, [])}
             )
 
+    def testcustom(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, lazy=False)
+        })
+        mapper(Address, addresses)
+        
+        selectquery = users.outerjoin(addresses).select(use_labels=True)
+        q = create_session().query(User)
+        
+        l = q.options(contains_eager('addresses')).instances(selectquery.execute())
+#        l = q.instances(selectquery.execute())
+        self.assert_result(l, User, *user_address_result)
+        
     def testorderby_desc(self):
         m = mapper(Address, addresses)