]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- the whole OperationContext/QueryContext/SelectionContext thing greatly scaled back;
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 21 Oct 2007 16:36:34 +0000 (16:36 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 21 Oct 2007 16:36:34 +0000 (16:36 +0000)
all MapperOptions process the Query and that's it, one very simpliied QueryContext object gets passed
around at query.compile() and query.instances() time
- slight optimization to MapperExtension allowing the mapper to check for the presence of an extended method, takes 3000 calls off of masseagerload.py test (only a slight increase in speed though)
- attempting to centralize the notion of a "path" along mappers/properties, need to define what that is better.  heading towards [ticket:777]...

lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py

index cd3dfbec0682c8b9160debb5942adf2a73fbde42..acbb115051205f6a1b36a5832c4bbca469121d32 100644 (file)
@@ -10,7 +10,7 @@ from sqlalchemy.sql import expression
 
 __all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension',
            'MapperProperty', 'PropComparator', 'StrategizedProperty', 
-           'LoaderStack', 'OperationContext', 'MapperOption', 
+           'LoaderStack', 'build_path', 'MapperOption', 
            'ExtensionOption', 'SynonymProperty', 'PropertyOption', 
            'AttributeExtension', 'StrategizedOption', 'LoaderStrategy' ]
 
@@ -493,6 +493,12 @@ class StrategizedProperty(MapperProperty):
         if self.is_primary():
             self.strategy.init_class_attribute()
 
+def build_path(mapper, key, prev=None):
+    if prev:
+        return prev + (mapper.base_mapper, key)
+    else:
+        return (mapper.base_mapper, key)
+        
 class LoaderStack(object):
     """a stack object used during load operations to track the 
     current position among a chain of mappers to eager loaders."""
@@ -521,33 +527,10 @@ class LoaderStack(object):
         
     def __str__(self):
         return "->".join([str(s) for s in self.__stack])
-        
-class OperationContext(object):
-    """Serve as a context during a query construction or instance
-    loading operation.
 
-    Accept ``MapperOption`` objects which may modify its state before proceeding.
-    """
-
-    def __init__(self, mapper, options, attributes=None):
-        self.mapper = mapper
-        self.options = options
-        self.attributes = attributes or {}
-        self.recursion_stack = util.Set()
-        for opt in util.flatten_iterator(options):
-            self.accept_option(opt)
-
-    def accept_option(self, opt):
-        pass
 
 class MapperOption(object):
-    """Describe a modification to an OperationContext or Query."""
-
-    def process_query_context(self, context):
-        pass
-
-    def process_selection_context(self, context):
-        pass
+    """Describe a modification to a Query."""
 
     def process_query(self, query):
         pass
@@ -598,24 +581,18 @@ class PropertyOption(MapperOption):
     def __init__(self, key):
         self.key = key
 
-    def process_query_property(self, context, properties):
-        pass
+    def process_query(self, query):
+        self.process_query_property(query, self._get_properties(query))
 
-    def process_selection_property(self, context, properties):
+    def process_query_property(self, query, properties):
         pass
 
-    def process_query_context(self, context):
-        self.process_query_property(context, self._get_properties(context))
-
-    def process_selection_context(self, context):
-        self.process_selection_property(context, self._get_properties(context))
-
-    def _get_properties(self, context):
+    def _get_properties(self, query):
         try:
             l = self.__prop
         except AttributeError:
             l = []
-            mapper = context.mapper
+            mapper = query.mapper
             for token in self.key.split('.'):
                 prop = mapper.get_property(token, resolve_synonyms=True)
                 l.append(prop)
@@ -649,21 +626,13 @@ class StrategizedOption(PropertyOption):
     def is_chained(self):
         return False
         
-    def process_query_property(self, context, properties):
-        self.logger.debug("applying option to QueryContext, property key '%s'" % self.key)
+    def process_query_property(self, query, properties):
+        self.logger.debug("applying option to Query, property key '%s'" % self.key)
         if self.is_chained():
             for prop in properties:
-                context.attributes[("loaderstrategy", prop)] = self.get_strategy_class()
+                query._attributes[("loaderstrategy", prop)] = self.get_strategy_class()
         else:
-            context.attributes[("loaderstrategy", properties[-1])] = self.get_strategy_class()
-
-    def process_selection_property(self, context, properties):
-        self.logger.debug("applying option to SelectionContext, property key '%s'" % self.key)
-        if self.is_chained():
-            for prop in properties:
-                context.attributes[("loaderstrategy", prop)] = self.get_strategy_class()
-        else:     
-            context.attributes[("loaderstrategy", properties[-1])] = self.get_strategy_class()
+            query._attributes[("loaderstrategy", properties[-1])] = self.get_strategy_class()
 
     def get_strategy_class(self):
         raise NotImplementedError()
index b68b4c8fe951c940d5fa0575decd466577a82b01..8b92e8b4e6926647cbb07cca458380437f9256aa 100644 (file)
@@ -666,7 +666,8 @@ class Mapper(object):
 
         def extra_init(class_, oldinit, instance, args, kwargs):
             self.compile()
-            self.extension.init_instance(self, class_, oldinit, instance, args, kwargs)
+            if 'init_instance' in self.extension.methods:
+                self.extension.init_instance(self, class_, oldinit, instance, args, kwargs)
         
         def on_exception(class_, oldinit, instance, args, kwargs):
             util.warn_exception(self.extension.init_failed, self, class_, oldinit, instance, args, kwargs)
@@ -843,12 +844,14 @@ class Mapper(object):
         Raise ``InvalidRequestError`` if a session cannot be retrieved
         from the extension chain.
         """
+        
+        if 'get_session' in self.extension.methods:
+            s = self.extension.get_session()
+            if s is not EXT_CONTINUE:
+                return s
 
-        s = self.extension.get_session()
-        if s is EXT_CONTINUE:
-            raise exceptions.InvalidRequestError("No contextual Session is established.  Use a MapperExtension that implements get_session or use 'import sqlalchemy.mods.threadlocal' to establish a default thread-local contextual session.")
-        return s
-
+        raise exceptions.InvalidRequestError("No contextual Session is established.  Use a MapperExtension that implements get_session or use 'import sqlalchemy.mods.threadlocal' to establish a default thread-local contextual session.")
+            
     def has_eager(self):
         """Return True if one of the properties attached to this
         Mapper is eager loading.
@@ -969,10 +972,12 @@ class Mapper(object):
             for obj, connection in tups:
                 if not has_identity(obj):
                     for mapper in object_mapper(obj).iterate_to_root():
-                        mapper.extension.before_insert(mapper, connection, obj)
+                        if 'before_insert' in mapper.extension.methods:
+                            mapper.extension.before_insert(mapper, connection, obj)
                 else:
                     for mapper in object_mapper(obj).iterate_to_root():
-                        mapper.extension.before_update(mapper, connection, obj)
+                        if 'before_update' in mapper.extension.methods:
+                            mapper.extension.before_update(mapper, connection, obj)
 
         for obj, connection in tups:
             # detect if we have a "pending" instance (i.e. has no instance_key attached to it),
@@ -1157,10 +1162,12 @@ class Mapper(object):
         if not postupdate:
             for obj, connection in inserted_objects:
                 for mapper in object_mapper(obj).iterate_to_root():
-                    mapper.extension.after_insert(mapper, connection, obj)
+                    if 'after_insert' in mapper.extension.methods:
+                        mapper.extension.after_insert(mapper, connection, obj)
             for obj, connection in updated_objects:
                 for mapper in object_mapper(obj).iterate_to_root():
-                    mapper.extension.after_update(mapper, connection, obj)
+                    if 'after_update' in mapper.extension.methods:
+                        mapper.extension.after_update(mapper, connection, obj)
 
     def _postfetch(self, connection, table, obj, resultproxy, params, value_params):
         """After an ``INSERT`` or ``UPDATE``, assemble newly generated
@@ -1209,7 +1216,8 @@ class Mapper(object):
 
         for (obj, connection) in tups:
             for mapper in object_mapper(obj).iterate_to_root():
-                mapper.extension.before_delete(mapper, connection, obj)
+                if 'before_delete' in mapper.extension.methods:
+                    mapper.extension.before_delete(mapper, connection, obj)
         
         deleted_objects = util.Set()
         table_to_mapper = {}
@@ -1255,7 +1263,8 @@ class Mapper(object):
 
         for obj, connection in deleted_objects:
             for mapper in object_mapper(obj).iterate_to_root():
-                mapper.extension.after_delete(mapper, connection, obj)
+                if 'after_delete' in mapper.extension.methods:
+                    mapper.extension.after_delete(mapper, connection, obj)
 
     def _has_pks(self, table):
         try:
@@ -1355,9 +1364,10 @@ class Mapper(object):
         else:
             extension = self.extension
 
-        ret = extension.translate_row(self, context, row)
-        if ret is not EXT_CONTINUE:
-            row = ret
+        if 'translate_row' in extension.methods:
+            ret = extension.translate_row(self, context, row)
+            if ret is not EXT_CONTINUE:
+                row = ret
 
         if not skip_polymorphic and self.polymorphic_on is not None:
             discriminator = row[self.polymorphic_on]
@@ -1392,10 +1402,10 @@ class Mapper(object):
                 if identitykey not in local_identity_map:
                     local_identity_map[identitykey] = instance
                     isnew = True
-                if extension.populate_instance(self, context, row, instance, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
+                if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
                     self.populate_instance(context, instance, row, instancekey=identitykey, isnew=isnew)
 
-            if extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
+            if 'append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
                 if result is not None:
                     result.append(instance)
             return instance
@@ -1420,9 +1430,13 @@ class Mapper(object):
                     return None
 
             # plugin point
-            instance = extension.create_instance(self, context, row, self.class_)
-            if instance is EXT_CONTINUE:
+            if 'create_instance' in extension.methods:
+                instance = extension.create_instance(self, context, row, self.class_)
+                if instance is EXT_CONTINUE:
+                    instance = attribute_manager.new_instance(self.class_)
+            else:
                 instance = attribute_manager.new_instance(self.class_)
+                
             instance._entity_name = self.entity_name
             if self.__should_log_debug:
                 self.__log_debug("_instance(): created new instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey)))
@@ -1435,9 +1449,9 @@ class Mapper(object):
         # call further mapper properties on the row, to pull further
         # instances from the row and possibly populate this item.
         flags = {'instancekey':identitykey, 'isnew':isnew}
-        if extension.populate_instance(self, context, row, instance, **flags) is EXT_CONTINUE:
+        if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, **flags) is EXT_CONTINUE:
             self.populate_instance(context, instance, row, **flags)
-        if extension.append_result(self, context, row, instance, result, **flags) is EXT_CONTINUE:
+        if 'append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, **flags) is EXT_CONTINUE:
             if result is not None:
                 result.append(instance)
                 
index f6268579f20955b2cd94eb607ae09e44d5c7e163..e5534e22c644dd1f3e7c01571f1205effe5939ea 100644 (file)
@@ -9,10 +9,10 @@ from sqlalchemy.sql import util as sql_util
 from sqlalchemy.sql import expression, visitors
 from sqlalchemy.orm import mapper, object_mapper
 from sqlalchemy.orm import util as mapperutil
-from sqlalchemy.orm.interfaces import OperationContext, LoaderStack
+from sqlalchemy.orm.interfaces import LoaderStack
 import operator
 
-__all__ = ['Query', 'QueryContext', 'SelectionContext']
+__all__ = ['Query', 'QueryContext']
 
 class Query(object):
     """Encapsulates the object-fetching operations provided by Mappers."""
@@ -46,6 +46,8 @@ class Query(object):
         self._populate_existing = False
         self._version_check = False
         self._autoflush = True
+        self._eager_loaders = util.Set([x for x in self.mapper._eager_loaders])
+        self._attributes = {}
         
     def _clone(self):
         q = Query.__new__(Query)
@@ -245,6 +247,9 @@ class Query(object):
         """
         
         q = self._clone()
+        # most MapperOptions write to the '_attributes' dictionary,
+        # so copy that as well
+        q._attributes = q._attributes.copy()
         opts = [o for o in util.flatten_iterator(args)]
         q._with_options = q._with_options + opts
         for opt in opts:
@@ -638,10 +643,9 @@ class Query(object):
 
         session = self.session
 
-        kwargs.setdefault('populate_existing', self._populate_existing)
-        kwargs.setdefault('version_check', self._version_check)
-        
-        context = SelectionContext(self.select_mapper, session, self._extension, with_options=self._with_options, **kwargs)
+        context = kwargs.pop('querycontext', None)
+        if context is None:
+            context = QueryContext(self)
 
         process = []
         mappers_or_columns = tuple(self._entities) + mappers_or_columns
@@ -725,18 +729,6 @@ class Query(object):
         except IndexError:
             return None
 
-    def _should_nest(self, querycontext):
-        """Return True if the given statement options indicate that we
-        should *nest* the generated query as a subquery inside of a
-        larger eager-loading query.  This is used with keywords like
-        distinct, limit and offset and the mapper defines eager loads.
-        """
-
-        return (
-            len(querycontext.eager_loaders) > 0
-            and self._nestable(**querycontext.select_args())
-        )
-
     def _nestable(self, **kwargs):
         """Return true if the given statement options imply it should be nested."""
 
@@ -767,7 +759,7 @@ class Query(object):
         whereclause = self._criterion
 
         context = QueryContext(self)
-        from_obj = context.from_obj
+        from_obj = self._from_obj
 
         alltables = []
         for l in [sql_util.TableFinder(x) for x in from_obj]:
@@ -775,11 +767,11 @@ class Query(object):
 
         if self.table not in alltables:
             from_obj.append(self.table)
-        if self._nestable(**context.select_args()):
-            s = sql.select([self.table], whereclause, from_obj=from_obj, **context.select_args()).alias('getcount').count()
+        if self._nestable(**self._select_args()):
+            s = sql.select([self.table], whereclause, from_obj=from_obj, **self._select_args()).alias('getcount').count()
         else:
             primary_key = self.primary_key_columns
-            s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **context.select_args())
+            s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **self._select_args())
         if self._autoflush and not self._populate_existing:
             self.session._autoflush()
         return self.session.scalar(s, params=self._params, mapper=self.mapper)
@@ -812,12 +804,10 @@ class Query(object):
                 if isinstance(m, mapper.Mapper):
                     table = m.select_table
                     sql_util.ClauseAdapter(m.select_table).traverse(whereclause, stop_on=util.Set([m.select_table]))
+
+        from_obj = self._from_obj
         
-        # get/create query context.  get the ultimate compile arguments
-        # from there
-        order_by = context.order_by
-        from_obj = context.from_obj
-        lockmode = context.lockmode
+        order_by = self._order_by
         if order_by is False:
             order_by = self.mapper.order_by
         if order_by is False:
@@ -825,9 +815,9 @@ class Query(object):
                 order_by = self.table.default_order_by()
 
         try:
-            for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[lockmode]
+            for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[self._lockmode]
         except KeyError:
-            raise exceptions.ArgumentError("Unknown lockmode '%s'" % lockmode)
+            raise exceptions.ArgumentError("Unknown lockmode '%s'" % self._lockmode)
 
         # if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so
         # that we only load the appropriate types
@@ -841,7 +831,7 @@ class Query(object):
         if self.table not in alltables:
             from_obj.append(self.table)
 
-        if self._should_nest(context):
+        if self._eager_loaders and self._nestable(**self._select_args()):
             # if theres an order by, add those columns to the column list
             # of the "rowcount" query we're going to make
             if order_by:
@@ -852,7 +842,7 @@ class Query(object):
             else:
                 cf = []
 
-            s2 = sql.select(self.primary_key_columns + list(cf), whereclause, use_labels=True, from_obj=from_obj, correlate=False, **context.select_args())
+            s2 = sql.select(self.primary_key_columns + list(cf), whereclause, use_labels=True, from_obj=from_obj, correlate=False, **self._select_args())
             if order_by:
                 s2 = s2.order_by(*util.to_list(order_by))
             s3 = s2.alias('tbl_row_count')
@@ -863,7 +853,7 @@ class Query(object):
             if order_by:
                 statement.append_order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by))
         else:
-            statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **context.select_args())
+            statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **self._select_args())
             if order_by:
                 statement.append_order_by(*util.to_list(order_by))
                 
@@ -871,7 +861,7 @@ class Query(object):
             # to use it in "order_by".  ensure they are in the column criterion (particularly oid).
             # TODO: this should be done at the SQL level not the mapper level
             # TODO: need test coverage for this 
-            if context.distinct and order_by:
+            if self._distinct and order_by:
                 [statement.append_column(c) for c in util.to_list(order_by)]
 
         context.statement = statement
@@ -896,10 +886,17 @@ class Query(object):
                 
         return context
 
+    def _select_args(self):
+        """Return a dictionary of attributes that can be applied to a ``sql.Select`` statement.
+        """
+        return {'limit':self._limit, 'offset':self._offset, 'distinct':self._distinct, 'group_by':self._group_by or None}
+
+
     def _get_entity_clauses(self, m):
         """for tuples added via add_entity() or add_column(), attempt to locate
         an AliasedClauses object which should be used to formulate the query as well
         as to process result rows."""
+        
         (m, alias, alias_id) = m
         if alias is not None:
             return alias
@@ -1151,95 +1148,19 @@ for deprecated_method in ['list', 'scalar', 'count_by',
 
 Query.logger = logging.class_logger(Query)
 
-class QueryContext(OperationContext):
-    """Created within the ``Query.compile()`` method to store and
-    share state among all the Mappers and MapperProperty objects used
-    in a query construction.
-    """
-
+class QueryContext(object):
     def __init__(self, query):
         self.query = query
-        self.order_by = query._order_by
-        self.group_by = query._group_by
-        self.from_obj = query._from_obj
-        self.lockmode = query._lockmode
-        self.distinct = query._distinct
-        self.limit = query._limit
-        self.offset = query._offset
-        self.eager_loaders = util.Set([x for x in query.mapper._eager_loaders])
+        self.mapper = query.mapper
+        self.session = query.session
+        self.extension = query._extension
         self.statement = None
-        super(QueryContext, self).__init__(query.mapper, query._with_options)
-
-    def select_args(self):
-        """Return a dictionary of attributes from this
-        ``QueryContext`` that can be applied to a ``sql.Select``
-        statement.
-        """
-        return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct, 'group_by':self.group_by or None}
-
-    def accept_option(self, opt):
-        """Accept a ``MapperOption`` which will process (modify) the
-        state of this ``QueryContext``.
-        """
-
-        opt.process_query_context(self)
-
-
-class SelectionContext(OperationContext):
-    """Created within the ``query.instances()`` method to store and share
-    state among all the Mappers and MapperProperty objects used in a
-    load operation.
-
-    SelectionContext contains these attributes:
-
-    mapper
-      The Mapper which originated the instances() call.
-
-    session
-      The Session that is relevant to the instances call.
-
-    identity_map
-      A dictionary which stores newly created instances that have not
-      yet been added as persistent to the Session.
-
-    attributes
-      A dictionary to store arbitrary data; mappers, strategies, and
-      options all store various state information here in order
-      to communicate with each other and to themselves.
-      
-
-    populate_existing
-      Indicates if its OK to overwrite the attributes of instances
-      that were already in the Session.
-
-    version_check
-      Indicates if mappers that have version_id columns should verify
-      that instances existing already within the Session should have
-      this attribute compared to the freshly loaded value.
-      
-    querycontext
-      the QueryContext, if any, used to generate the executed statement.
-      If present, the attribute dictionary from this Context will be used
-      as the basis for this SelectionContext's attribute dictionary.  This
-      allows query-compile-time operations to send messages to the 
-      result-processing-time operations.
-    """
-
-    def __init__(self, mapper, session, extension, **kwargs):
-        self.populate_existing = kwargs.pop('populate_existing', False)
-        self.version_check = kwargs.pop('version_check', False)
-        querycontext = kwargs.pop('querycontext', None)
-        if querycontext:
-            kwargs['attributes'] = querycontext.attributes
-        self.session = session
-        self.extension = extension
+        self.populate_existing = query._populate_existing
+        self.version_check = query._version_check
         self.identity_map = {}
         self.stack = LoaderStack()
-        super(SelectionContext, self).__init__(mapper, kwargs.pop('with_options', []), **kwargs)
-            
-    def accept_option(self, opt):
-        """Accept a MapperOption which will process (modify) the state
-        of this SelectionContext.
-        """
 
-        opt.process_selection_context(self)
+        self.options = query._with_options
+        self.attributes = query._attributes.copy()
+        
+
index 716a6dbba544383390affda96ffd7778ce73abf0..e2a5be696cd16d4378dbbb2efd58f735c13eb6a7 100644 (file)
@@ -239,11 +239,8 @@ class DeferredOption(StrategizedOption):
 class UndeferGroupOption(MapperOption):
     def __init__(self, group):
         self.group = group
-    def process_query_context(self, context):
-        context.attributes[('undefer', self.group)] = True
-
-    def process_selection_context(self, context):
-        context.attributes[('undefer', self.group)] = True
+    def process_query(self, query):
+        query._attributes[('undefer', self.group)] = True
 
 class AbstractRelationLoader(LoaderStrategy):
     def init(self):
@@ -665,14 +662,13 @@ class EagerLazyOption(StrategizedOption):
     def is_chained(self):
         return not self.lazy and self.chained
         
-    def process_query_property(self, context, properties):
+    def process_query_property(self, query, properties):
         if self.lazy:
-            if properties[-1] in context.eager_loaders:
-                context.eager_loaders.remove(properties[-1])
+            if properties[-1] in query._eager_loaders:
+                query._eager_loaders = query._eager_loaders.difference(util.Set([properties[-1]]))
         else:
-            for prop in properties:
-                context.eager_loaders.add(prop)
-        super(EagerLazyOption, self).process_query_property(context, properties)
+            query._eager_loaders = query._eager_loaders.union(util.Set(properties))
+        super(EagerLazyOption, self).process_query_property(query, properties)
 
     def get_strategy_class(self):
         if self.lazy:
@@ -697,8 +693,8 @@ class FetchModeOption(PropertyOption):
             raise exceptions.ArgumentError("Fetchmode must be one of 'join' or 'select'")
         self.type = type
         
-    def process_selection_property(self, context, properties):
-        context.attributes[('fetchmode', properties[-1])] = self.type
+    def process_query_property(self, query, properties):
+        query.attributes[('fetchmode', properties[-1])] = self.type
         
 class RowDecorateOption(PropertyOption):
     def __init__(self, key, decorator=None, alias=None):
@@ -706,7 +702,7 @@ class RowDecorateOption(PropertyOption):
         self.decorator = decorator
         self.alias = alias
 
-    def process_selection_property(self, context, properties):
+    def process_query_property(self, query, properties):
         if self.alias is not None and self.decorator is None:
             if isinstance(self.alias, basestring):
                 self.alias = properties[-1].target.alias(self.alias)
@@ -716,7 +712,7 @@ class RowDecorateOption(PropertyOption):
                     d[c] = row[self.alias.corresponding_column(c)]
                 return d
             self.decorator = decorate
-        context.attributes[("eager_row_processor", properties[-1])] = self.decorator
+        query._attributes[("eager_row_processor", properties[-1])] = self.decorator
 
 RowDecorateOption.logger = logging.class_logger(RowDecorateOption)
         
index 30ecbdfe854cc1a1f006fc86cd2d107a7b710478..f4294502b1a2fcd98565f43052bb6a1de9c87809 100644 (file)
@@ -7,7 +7,7 @@
 from sqlalchemy import sql, util, exceptions
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.sql import visitors
-from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE
+from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, build_path
 
 all_cascades = util.Set(["delete", "delete-orphan", "all", "merge",
                          "expunge", "save-update", "refresh-expire", "none"])
@@ -112,10 +112,19 @@ class TranslatingDict(dict):
     def setdefault(self, col, value):
         return super(TranslatingDict, self).setdefault(self.__translate_col(col), value)
 
-class ExtensionCarrier(MapperExtension):
+class ExtensionCarrier(object):
+    """stores a collection of MapperExtension objects.
+    
+    allows an extension methods to be called on contained MapperExtensions
+    in the order they were added to this object.  Also includes a 'methods' dictionary
+    accessor which allows for a quick check if a particular method
+    is overridden on any contained MapperExtensions.
+    """
+    
     def __init__(self, _elements=None):
         self.__elements = _elements or []
-
+        self.methods = {}
+        
     def copy(self):
         return ExtensionCarrier(list(self.__elements))
         
@@ -125,43 +134,40 @@ class ExtensionCarrier(MapperExtension):
     def insert(self, extension):
         """Insert a MapperExtension at the beginning of this ExtensionCarrier's list."""
 
-        self.__elements.insert(0, extension)
+        self.__elements.insert(0, self.__inspect(extension))
 
     def append(self, extension):
         """Append a MapperExtension at the end of this ExtensionCarrier's list."""
 
-        self.__elements.append(extension)
+        self.__elements.append(self.__inspect(extension))
 
-    def _create_do(funcname):
-        def _do(self, *args, **kwargs):
+    def __inspect(self, extension):
+        for meth in MapperExtension.__dict__.keys():
+            if meth not in self.methods and hasattr(extension, meth) and getattr(extension, meth) is not getattr(MapperExtension, meth):
+                self.methods[meth] = self.__create_do(meth)
+        return extension
+           
+    def __create_do(self, funcname):
+        def _do(*args, **kwargs):
             for elem in self.__elements:
                 ret = getattr(elem, funcname)(*args, **kwargs)
                 if ret is not EXT_CONTINUE:
                     return ret
             else:
                 return EXT_CONTINUE
-        return _do
 
-    instrument_class = _create_do('instrument_class')
-    init_instance = _create_do('init_instance')
-    init_failed = _create_do('init_failed')
-    dispose_class = _create_do('dispose_class')
-    get_session = _create_do('get_session')
-    load = _create_do('load')
-    get = _create_do('get')
-    get_by = _create_do('get_by')
-    select_by = _create_do('select_by')
-    select = _create_do('select')
-    translate_row = _create_do('translate_row')
-    create_instance = _create_do('create_instance')
-    append_result = _create_do('append_result')
-    populate_instance = _create_do('populate_instance')
-    before_insert = _create_do('before_insert')
-    before_update = _create_do('before_update')
-    after_update = _create_do('after_update')
-    after_insert = _create_do('after_insert')
-    before_delete = _create_do('before_delete')
-    after_delete = _create_do('after_delete')
+        try:
+            _do.__name__ = funcname
+        except:
+            # cant set __name__ in py 2.3 
+            pass
+        return _do
+    
+    def _pass(self, *args, **kwargs):
+        return EXT_CONTINUE
+        
+    def __getattr__(self, key):
+        return self.methods.get(key, self._pass)
 
 class BinaryVisitor(visitors.ClauseVisitor):
     def __init__(self, func):
@@ -262,9 +268,9 @@ class PropertyAliasedClauses(AliasedClauses):
             
         self.parentclauses = parentclauses
         if parentclauses is not None:
-            self.path = parentclauses.path + (prop.parent, prop.key)
+            self.path = build_path(prop.parent, prop.key, parentclauses.path)
         else:
-            self.path = (prop.parent, prop.key)
+            self.path = build_path(prop.parent, prop.key)
 
         self.prop = prop