git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git / commitdiff
author    Mike Bayer <mike_mp@zzzcomputing.com>
          Wed, 20 Jun 2012 23:28:29 +0000 (19:28 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
          Wed, 20 Jun 2012 23:28:29 +0000 (19:28 -0400)

- [feature] The of_type() construct on attributes
  now accepts aliased() class constructs as well
  as with_polymorphic constructs, and works with
  query.join(), any(), has(), and also the
  eager loaders subqueryload(), joinedload(), and
  contains_eager() (see the first sketch below).
  [ticket:2438] [ticket:1106]
- a rewrite of the query path system to use an
  object-based approach for more succinct usage.  the system
  has been designed carefully to not add excessive method overhead.
- [feature] select() features a correlate_except()
  method, which auto-correlates all selectables except those
  passed.  This is needed here for the updated any()/has()
  functionality (see the second sketch below).
- remove some old cruft from LoaderStrategy: init(), debug_callable()
- use a namedtuple for _extended_entity_info.  This method should
  become standard within the orm internals
- some tweaks to the memory profile tests, number of runs can
  be customized to work around pysqlite's very annoying behavior
- try to simplify PropertyOption._get_paths(), renamed to _process_paths();
  it returns a single list now.  overall it works more completely, as was
  needed for the of_type() functionality

22 files changed:
CHANGES
doc/build/orm/inheritance.rst
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/state.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/compat.py
test/aaa_profiling/test_memusage.py
test/orm/inheritance/_poly_fixtures.py
test/orm/inheritance/test_polymorphic_rel.py
test/orm/test_merge.py
test/orm/test_pickled.py
test/orm/test_query.py
test/orm/test_subquery_relations.py
test/perf/orm2010.py

diff --git a/CHANGES b/CHANGES
index 987bc2dc28ec8d34b8db6d7aed766a9933214ddf..77d769fb0922dccf04615cea243c50d7a9b6aa0d 100644
--- a/CHANGES
+++ b/CHANGES
@@ -49,6 +49,14 @@ CHANGES
     of a join in place of the "of_type()" modifier.
     [ticket:2333]
 
+  - [feature] The of_type() construct on attributes
+    now accepts aliased() class constructs as well
+    as with_polymorphic constructs, and works with
+    query.join(), any(), has(), and also
+    eager loaders subqueryload(), joinedload(),
+    contains_eager()
+    [ticket:2438] [ticket:1106]
+
   - [feature] The "deferred declarative 
     reflection" system has been moved into the 
     declarative extension itself, using the
@@ -296,6 +304,10 @@ CHANGES
     that aren't in the target table is now an exception.
     [ticket:2415]
 
+  - [feature] select() features a correlate_except()
+    method, auto correlates all selectables except those
+    passed.
+
   - [bug] All of UniqueConstraint, ForeignKeyConstraint,
     CheckConstraint, and PrimaryKeyConstraint will
     attach themselves to their parent table automatically
diff --git a/doc/build/orm/inheritance.rst b/doc/build/orm/inheritance.rst
index ba0745065d02f9fb220898b027f6c0016fab676a..c0185bec125b3c2d2de36e802cf2f741e07d2b51 100644
--- a/doc/build/orm/inheritance.rst
+++ b/doc/build/orm/inheritance.rst
@@ -423,6 +423,20 @@ Using ``aliased=True`` instead renders it more like::
 
     FROM x JOIN (SELECT * FROM y JOIN z ON <onclause>) AS anon_1 ON <onclause>
 
+The above join can also be expressed more succinctly by combining ``of_type()``
+with the polymorphic construct::
+
+    manager_and_engineer = with_polymorphic(
+                                Employee, [Manager, Engineer], 
+                                aliased=True)
+
+    session.query(Company).\
+        join(Company.employees.of_type(manager_and_engineer)).\
+        filter(
+            or_(manager_and_engineer.Engineer.engineer_info=='someinfo', 
+                manager_and_engineer.Manager.manager_data=='somedata')
+        )
+
 The ``any()`` and ``has()`` operators also can be used with
 :func:`~sqlalchemy.orm.interfaces.PropComparator.of_type` when the embedded
 criterion is in terms of a subclass::
@@ -448,6 +462,28 @@ The EXISTS subquery above selects from the join of ``employees`` to
 ``engineers``, and also specifies criterion which correlates the EXISTS
 subselect back to the parent ``companies`` table.
 
+.. versionadded:: 0.8
+   :func:`~sqlalchemy.orm.interfaces.PropComparator.of_type` accepts
+   :func:`.orm.aliased` and :func:`.orm.with_polymorphic` constructs in conjunction
+   with :meth:`.Query.join`, ``any()`` and ``has()``.
+
+Eager Loading of Specific Subtypes
+++++++++++++++++++++++++++++++++++
+
+The :func:`.joinedload` and :func:`.subqueryload` options also support
+paths which make use of :func:`~sqlalchemy.orm.interfaces.PropComparator.of_type`.
+Below we load ``Company`` rows while eagerly loading related ``Engineer`` 
+objects, querying the ``employee`` and ``engineer`` tables simultaneously::
+
+    session.query(Company).\
+        options(subqueryload_all(Company.employees.of_type(Engineer), 
+                        Engineer.machines))
+
+.. versionadded:: 0.8
+    :func:`.joinedload` and :func:`.subqueryload` support
+    paths that are qualified with 
+    :func:`~sqlalchemy.orm.interfaces.PropComparator.of_type`.
+
 Single Table Inheritance
 ------------------------
 
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index 7c5955c5db455a18bd775b7d6d672485d1f37b95..1750bc9f89d63a189530af648221bf6cf4cea1f3 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -112,7 +112,8 @@ __all__ = (
     'synonym',
     'undefer',
     'undefer_group',
-    'validates'
+    'validates',
+    'with_polymorphic'
     )
 
 
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 55e0291b5033257b46dda516361b3e97ee028e8a..e71752ab52b1c1d047558d3469741cd2faf18f31 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -103,12 +103,14 @@ class QueryableAttribute(interfaces.PropComparator):
     """Base class for class-bound attributes. """
 
     def __init__(self, class_, key, impl=None, 
-                        comparator=None, parententity=None):
+                        comparator=None, parententity=None,
+                        of_type=None):
         self.class_ = class_
         self.key = key
         self.impl = impl
         self.comparator = comparator
         self.parententity = parententity
+        self._of_type = of_type
 
         manager = manager_of_class(class_)
         # manager is None in the case of AliasedClass
@@ -137,6 +139,15 @@ class QueryableAttribute(interfaces.PropComparator):
     def __clause_element__(self):
         return self.comparator.__clause_element__()
 
+    def of_type(self, cls):
+        return QueryableAttribute(
+                    self.class_,
+                    self.key,
+                    self.impl,
+                    self.comparator.of_type(cls),
+                    self.parententity,
+                    of_type=cls)
+
     def label(self, name):
         return self.__clause_element__().label(name)
 
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index bda48cbb130f5d96c3b45d28cd0a36a3ba21028e..8d185e9f3b94cfd8496a034c75a3fba2c6d3c128 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -42,7 +42,6 @@ __all__ = (
     'SessionExtension',
     'StrategizedOption',
     'StrategizedProperty',
-    'build_path',
     )
 
 EXT_CONTINUE = util.symbol('EXT_CONTINUE')
@@ -77,7 +76,7 @@ class MapperProperty(object):
 
     """
 
-    def setup(self, context, entity, path, reduced_path, adapter, **kwargs):
+    def setup(self, context, entity, path, adapter, **kwargs):
         """Called by Query for the purposes of constructing a SQL statement.
 
         Each MapperProperty associated with the target mapper processes the
@@ -87,7 +86,7 @@ class MapperProperty(object):
 
         pass
 
-    def create_row_processor(self, context, path, reduced_path, 
+    def create_row_processor(self, context, path, 
                                             mapper, row, adapter):
         """Return a 3-tuple consisting of three row processing functions.
 
@@ -112,7 +111,7 @@ class MapperProperty(object):
     def set_parent(self, parent, init):
         self.parent = parent
 
-    def instrument_class(self, mapper):
+    def instrument_class(self, mapper):  # pragma: no-coverage
         raise NotImplementedError()
 
     _compile_started = False
@@ -308,15 +307,23 @@ class StrategizedProperty(MapperProperty):
 
     strategy_wildcard_key = None
 
-    def _get_context_strategy(self, context, reduced_path):
-        key = ('loaderstrategy', reduced_path)
+    @util.memoized_property
+    def _wildcard_path(self):
+        if self.strategy_wildcard_key:
+            return ('loaderstrategy', (self.strategy_wildcard_key,))
+        else:
+            return None
+
+    def _get_context_strategy(self, context, path):
+        # this is essentially performance inlining.
+        key = ('loaderstrategy', path.reduced_path + (self.key,))
         cls = None
         if key in context.attributes:
             cls = context.attributes[key]
-        elif self.strategy_wildcard_key:
-            key = ('loaderstrategy', (self.strategy_wildcard_key,))
-            if key in context.attributes:
-                cls = context.attributes[key]
+        else:
+            wc_key = self._wildcard_path
+            if wc_key and wc_key in context.attributes:
+                cls = context.attributes[wc_key]
 
         if cls:
             try:
@@ -335,15 +342,15 @@ class StrategizedProperty(MapperProperty):
         self._strategies[cls] = strategy = cls(self)
         return strategy
 
-    def setup(self, context, entity, path, reduced_path, adapter, **kwargs):
-        self._get_context_strategy(context, reduced_path + (self.key,)).\
+    def setup(self, context, entity, path, adapter, **kwargs):
+        self._get_context_strategy(context, path).\
                     setup_query(context, entity, path, 
-                                    reduced_path, adapter, **kwargs)
+                                    adapter, **kwargs)
 
-    def create_row_processor(self, context, path, reduced_path, mapper, row, adapter):
-        return self._get_context_strategy(context, reduced_path + (self.key,)).\
+    def create_row_processor(self, context, path, mapper, row, adapter):
+        return self._get_context_strategy(context, path).\
                     create_row_processor(context, path, 
-                                    reduced_path, mapper, row, adapter)
+                                    mapper, row, adapter)
 
     def do_init(self):
         self._strategies = {}
@@ -354,30 +361,6 @@ class StrategizedProperty(MapperProperty):
             not mapper.class_manager._attr_has_impl(self.key):
             self.strategy.init_class_attribute(mapper)
 
-def build_path(entity, key, prev=None):
-    if prev:
-        return prev + (entity, key)
-    else:
-        return (entity, key)
-
-def serialize_path(path):
-    if path is None:
-        return None
-
-    return zip(
-        [m.class_ for m in [path[i] for i in range(0, len(path), 2)]], 
-        [path[i] for i in range(1, len(path), 2)] + [None]
-    )
-
-def deserialize_path(path):
-    if path is None:
-        return None
-
-    p = tuple(chain(*[(mapperutil.class_mapper(cls), key) for cls, key in path]))
-    if p and p[-1] is None:
-        p = p[0:-1]
-    return p
-
 class MapperOption(object):
     """Describe a modification to a Query."""
 
@@ -414,11 +397,11 @@ class PropertyOption(MapperOption):
         self._process(query, False)
 
     def _process(self, query, raiseerr):
-        paths, mappers = self._get_paths(query, raiseerr)
+        paths = self._process_paths(query, raiseerr)
         if paths:
-            self.process_query_property(query, paths, mappers)
+            self.process_query_property(query, paths)
 
-    def process_query_property(self, query, paths, mappers):
+    def process_query_property(self, query, paths):
         pass
 
     def __getstate__(self):
@@ -450,8 +433,7 @@ class PropertyOption(MapperOption):
             searchfor = mapperutil._class_to_mapper(mapper)
             isa = True
         for ent in query._mapper_entities:
-            if searchfor is ent.path_entity or isa \
-                and searchfor.common_parent(ent.path_entity):
+            if ent.corresponds_to(searchfor):
                 return ent
         else:
             if raiseerr:
@@ -488,15 +470,21 @@ class PropertyOption(MapperOption):
             else:
                 return None
 
-    def _get_paths(self, query, raiseerr):
-        path = None
+    def _process_paths(self, query, raiseerr):
+        """reconcile the 'key' for this PropertyOption with
+        the current path and entities of the query.
+        
+        Return a list of affected paths.
+        
+        """
+        path = mapperutil.PathRegistry.root
         entity = None
-        l = []
-        mappers = []
+        paths = []
+        no_result = []
 
         # _current_path implies we're in a 
         # secondary load with an existing path
-        current_path = list(query._current_path)
+        current_path = list(query._current_path.path)
 
         tokens = deque(self.key)
         while tokens:
@@ -504,7 +492,7 @@ class PropertyOption(MapperOption):
             if isinstance(token, basestring):
                 # wildcard token
                 if token.endswith(':*'):
-                    return [(token,)], []
+                    return [path.token(token)]
                 sub_tokens = token.split(".", 1)
                 token = sub_tokens[0]
                 tokens.extendleft(sub_tokens[1:])
@@ -516,7 +504,7 @@ class PropertyOption(MapperOption):
                         current_path = current_path[2:]
                         continue
                     else:
-                        return [], []
+                        return no_result
 
                 if not entity:
                     entity = self._find_entity_basestring(
@@ -524,10 +512,10 @@ class PropertyOption(MapperOption):
                                         token, 
                                         raiseerr)
                     if entity is None:
-                        return [], []
-                    path_element = entity.path_entity
+                        return no_result
+                    path_element = entity.entity_zero
                     mapper = entity.mapper
-                mappers.append(mapper)
+
                 if hasattr(mapper.class_, token):
                     prop = getattr(mapper.class_, token).property
                 else:
@@ -538,7 +526,7 @@ class PropertyOption(MapperOption):
                                 token, mapper)
                         )
                     else:
-                        return [], []
+                        return no_result
             elif isinstance(token, PropComparator):
                 prop = token.property
 
@@ -550,7 +538,7 @@ class PropertyOption(MapperOption):
                         current_path = current_path[2:]
                         continue
                     else:
-                        return [], []
+                        return no_result
 
                 if not entity:
                     entity = self._find_entity_prop_comparator(
@@ -559,10 +547,9 @@ class PropertyOption(MapperOption):
                                             token.parententity, 
                                             raiseerr)
                     if not entity:
-                        return [], []
-                    path_element = entity.path_entity
+                        return no_result
+                    path_element = entity.entity_zero
                     mapper = entity.mapper
-                mappers.append(prop.parent)
             else:
                 raise sa_exc.ArgumentError(
                         "mapper option expects "
@@ -572,11 +559,20 @@ class PropertyOption(MapperOption):
                 raise sa_exc.ArgumentError("Attribute '%s' does not "
                             "link from element '%s'" % (token, path_element))
 
-            path = build_path(path_element, prop.key, path)
+            path = path[path_element][prop.key]
+
+            paths.append(path)
 
-            l.append(path)
             if getattr(token, '_of_type', None):
-                path_element = mapper = token._of_type
+                ac = token._of_type
+                ext_info = mapperutil._extended_entity_info(ac)
+                path_element = mapper = ext_info.mapper
+                if not ext_info.is_aliased_class:
+                    ac = mapperutil.with_polymorphic(
+                                ext_info.mapper.base_mapper, 
+                                ext_info.mapper, aliased=True)
+                    ext_info = mapperutil._extended_entity_info(ac)
+                path.set(query, "path_with_polymorphic", ext_info)
             else:
                 path_element = mapper = getattr(prop, 'mapper', None)
                 if mapper is None and tokens:
@@ -590,9 +586,9 @@ class PropertyOption(MapperOption):
             # ran out of tokens before 
             # current_path was exhausted.
             assert not tokens
-            return [], []
+            return no_result
 
-        return l, mappers
+        return paths
 
 class StrategizedOption(PropertyOption):
     """A MapperOption that affects which LoaderStrategy will be used
@@ -601,40 +597,25 @@ class StrategizedOption(PropertyOption):
 
     chained = False
 
-    def process_query_property(self, query, paths, mappers):
-
-        # _get_context_strategy may receive the path in terms of a base
-        # mapper - e.g.  options(eagerload_all(Company.employees,
-        # Engineer.machines)) in the polymorphic tests leads to
-        # "(Person, 'machines')" in the path due to the mechanics of how
-        # the eager strategy builds up the path
-
+    def process_query_property(self, query, paths):
+        strategy = self.get_strategy_class()
         if self.chained:
             for path in paths:
-                query._attributes[('loaderstrategy',
-                                  _reduce_path(path))] = \
-                    self.get_strategy_class()
+                path.set(
+                    query,
+                    "loaderstrategy",
+                    strategy
+                )
         else:
-            query._attributes[('loaderstrategy',
-                              _reduce_path(paths[-1]))] = \
-                self.get_strategy_class()
+            paths[-1].set(
+                query,
+                "loaderstrategy",
+                strategy
+            )
 
     def get_strategy_class(self):
         raise NotImplementedError()
 
-def _reduce_path(path):
-    """Convert a (mapper, path) path to use base mappers.
-
-    This is used to allow more open ended selection of loader strategies, i.e.
-    Mapper -> prop1 -> Subclass -> prop2, where Subclass is a sub-mapper
-    of the mapper referenced by Mapper.prop1.
-
-    """
-    return tuple([i % 2 != 0 and 
-                    element or 
-                    getattr(element, 'base_mapper', element) 
-                    for i, element in enumerate(path)])
-
 class LoaderStrategy(object):
     """Describe the loading behavior of a StrategizedProperty object.
 
@@ -663,22 +644,14 @@ class LoaderStrategy(object):
         self.is_class_level = False
         self.parent = self.parent_property.parent
         self.key = self.parent_property.key
-        # TODO: there's no particular reason we need
-        # the separate .init() method at this point.
-        # It's possible someone has written their
-        # own LS object.
-        self.init()
-
-    def init(self):
-        raise NotImplementedError("LoaderStrategy")
 
     def init_class_attribute(self, mapper):
         pass
 
-    def setup_query(self, context, entity, path, reduced_path, adapter, **kwargs):
+    def setup_query(self, context, entity, path, adapter, **kwargs):
         pass
 
-    def create_row_processor(self, context, path, reduced_path, mapper, 
+    def create_row_processor(self, context, path, mapper, 
                                 row, adapter):
         """Return row processing functions which fulfill the contract
         specified by MapperProperty.create_row_processor.
@@ -691,16 +664,6 @@ class LoaderStrategy(object):
     def __str__(self):
         return str(self.parent_property)
 
-    def debug_callable(self, fn, logger, announcement, logfn):
-        if announcement:
-            logger.debug(announcement)
-        if logfn:
-            def call(*args, **kwargs):
-                logger.debug(logfn(*args, **kwargs))
-                return fn(*args, **kwargs)
-            return call
-        else:
-            return fn
 
 class InstrumentationManager(object):
     """User-defined class instrumentation extension.
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 2ec30f0bad9cd24450e14f62e500a8dcb68faca1..789f29c738660e0743080fc89c23e10371e22ba6 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -28,7 +28,8 @@ from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, \
                                 PropComparator
 
 from sqlalchemy.orm.util import _INSTRUMENTOR, _class_to_mapper, \
-     _state_mapper, class_mapper, instance_str, state_str
+     _state_mapper, class_mapper, instance_str, state_str,\
+     PathRegistry
 
 import sys
 sessionlib = util.importlater("sqlalchemy.orm", "session")
@@ -432,6 +433,10 @@ class Mapper(object):
 
     dispatch = event.dispatcher(events.MapperEvents)
 
+    @util.memoized_property
+    def _sa_path_registry(self):
+        return PathRegistry.per_mapper(self)
+
     def _configure_inheritance(self):
         """Configure settings related to inherting and/or inherited mappers
         being present."""
@@ -1302,13 +1307,12 @@ class Mapper(object):
 
         return mappers
 
-    def _selectable_from_mappers(self, mappers):
+    def _selectable_from_mappers(self, mappers, innerjoin):
         """given a list of mappers (assumed to be within this mapper's
         inheritance hierarchy), construct an outerjoin amongst those mapper's
         mapped tables.
 
         """
-
         from_obj = self.mapped_table
         for m in mappers:
             if m is self:
@@ -1318,7 +1322,11 @@ class Mapper(object):
                         "'with_polymorphic()' requires 'selectable' argument "
                         "when concrete-inheriting mappers are used.")
             elif not m.single:
-                from_obj = from_obj.outerjoin(m.local_table,
+                if innerjoin:
+                    from_obj = from_obj.join(m.local_table,
+                                                m.inherit_condition)
+                else:
+                    from_obj = from_obj.outerjoin(m.local_table,
                                                 m.inherit_condition)
 
         return from_obj
@@ -1350,9 +1358,11 @@ class Mapper(object):
             return selectable
         else:
             return self._selectable_from_mappers(
-                            self._mappers_from_spec(spec, selectable))
+                            self._mappers_from_spec(spec, selectable),
+                            False)
 
-    def _with_polymorphic_args(self, spec=None, selectable=False):
+    def _with_polymorphic_args(self, spec=None, selectable=False, 
+                                innerjoin=False):
         if self.with_polymorphic:
             if not spec:
                 spec = self.with_polymorphic[0]
@@ -1364,7 +1374,8 @@ class Mapper(object):
         if selectable is not None:
             return mappers, selectable
         else:
-            return mappers, self._selectable_from_mappers(mappers)
+            return mappers, self._selectable_from_mappers(mappers, 
+                                innerjoin)
 
     @_memoized_configured_property
     def _polymorphic_properties(self):
@@ -1926,7 +1937,7 @@ class Mapper(object):
         return result
 
 
-    def _instance_processor(self, context, path, reduced_path, adapter, 
+    def _instance_processor(self, context, path, adapter, 
                                 polymorphic_from=None, 
                                 only_load_props=None, refresh_state=None,
                                 polymorphic_discriminator=None):
@@ -1951,7 +1962,7 @@ class Mapper(object):
                 polymorphic_on = self.polymorphic_on
             polymorphic_instances = util.PopulateDict(
                                         self._configure_subclass_mapper(
-                                                context, path, reduced_path, adapter)
+                                                context, path, adapter)
                                         )
 
         version_id_col = self.version_id_col
@@ -1968,7 +1979,9 @@ class Mapper(object):
         new_populators = []
         existing_populators = []
         eager_populators = []
-        load_path = context.query._current_path + path
+        load_path = context.query._current_path + path \
+                    if context.query._current_path.path \
+                    else path
 
         def populate_state(state, dict_, row, isnew, only_load_props):
             if isnew:
@@ -1978,7 +1991,7 @@ class Mapper(object):
                     state.load_path = load_path
 
             if not new_populators:
-                self._populators(context, path, reduced_path, row, adapter,
+                self._populators(context, path, row, adapter,
                                 new_populators,
                                 existing_populators,
                                 eager_populators
@@ -2015,7 +2028,7 @@ class Mapper(object):
 
         def _instance(row, result):
             if not new_populators and invoke_all_eagers:
-                self._populators(context, path, reduced_path, row, adapter,
+                self._populators(context, path, row, adapter,
                                 new_populators,
                                 existing_populators,
                                 eager_populators
@@ -2191,16 +2204,17 @@ class Mapper(object):
             return instance
         return _instance
 
-    def _populators(self, context, path, reduced_path, row, adapter,
+    def _populators(self, context, path, row, adapter,
             new_populators, existing_populators, eager_populators):
-        """Produce a collection of attribute level row processor callables."""
+        """Produce a collection of attribute level row processor 
+        callables."""
 
         delayed_populators = []
-        pops = (new_populators, existing_populators, delayed_populators, eager_populators)
+        pops = (new_populators, existing_populators, delayed_populators, 
+                            eager_populators)
         for prop in self._props.itervalues():
             for i, pop in enumerate(prop.create_row_processor(
-                                        context, path, 
-                                        reduced_path,
+                                        context, path,
                                         self, row, adapter)):
                 if pop is not None:
                     pops[i].append((prop.key, pop))
@@ -2208,7 +2222,7 @@ class Mapper(object):
         if delayed_populators:
             new_populators.extend(delayed_populators)
 
-    def _configure_subclass_mapper(self, context, path, reduced_path, adapter):
+    def _configure_subclass_mapper(self, context, path, adapter):
         """Produce a mapper level row processor callable factory for mappers
         inheriting this one."""
 
@@ -2223,18 +2237,17 @@ class Mapper(object):
                 return None
 
             # replace the tip of the path info with the subclass mapper 
-            # being used. that way accurate "load_path" info is available 
-            # for options invoked during deferred loads.
-            # we lose AliasedClass path elements this way, but currently,
-            # those are not needed at this stage.
-
-            # this asserts to true
-            #assert mapper.isa(_class_to_mapper(path[-1]))
-
-            return mapper._instance_processor(context, path[0:-1] + (mapper,), 
-                                                    reduced_path[0:-1] + (mapper.base_mapper,),
-                                                    adapter,
-                                                    polymorphic_from=self)
+            # being used, that way accurate "load_path" info is available 
+            # for options invoked during deferred loads, e.g.
+            # query(Person).options(defer(Engineer.machines, Machine.name)).
+            # for AliasedClass paths, disregard this step (new in 0.8).
+            return mapper._instance_processor(
+                                context, 
+                                path.parent[mapper] 
+                                    if not path.is_aliased_class 
+                                    else path, 
+                                adapter,
+                                polymorphic_from=self)
         return configure_subclass_mapper
 
 log.class_logger(Mapper)
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 9a2d057545929c43640882684bb09750020c660b..5634a9c5fe8b8c5a86ad1117774c102a2a8595c1 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -18,7 +18,8 @@ from sqlalchemy.sql import operators, expression, visitors
 from sqlalchemy.orm import attributes, dependency, mapper, \
     object_mapper, strategies, configure_mappers, relationships
 from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, \
-    _orm_annotate, _orm_deannotate, _orm_full_deannotate
+    _orm_annotate, _orm_deannotate, _orm_full_deannotate,\
+    _entity_info
 
 from sqlalchemy.orm.interfaces import MANYTOMANY, MANYTOONE, \
     MapperProperty, ONETOMANY, PropComparator, StrategizedProperty
@@ -305,7 +306,7 @@ class RelationshipProperty(StrategizedProperty):
             self.mapper = mapper
             self.adapter = adapter
             if of_type:
-                self._of_type = _class_to_mapper(of_type)
+                self._of_type = of_type
 
         def adapted(self, adapter):
             """Return a copy of this PropComparator which will use the
@@ -318,7 +319,7 @@ class RelationshipProperty(StrategizedProperty):
                                   getattr(self, '_of_type', None),
                                   adapter)
 
-        @property
+        @util.memoized_property
         def parententity(self):
             return self.property.parent
 
@@ -406,9 +407,8 @@ class RelationshipProperty(StrategizedProperty):
 
         def _criterion_exists(self, criterion=None, **kwargs):
             if getattr(self, '_of_type', None):
-                target_mapper = self._of_type
-                to_selectable = target_mapper._with_polymorphic_selectable
-                if self.property._is_self_referential:
+                target_mapper, to_selectable, is_aliased_class = _entity_info(self._of_type)
+                if self.property._is_self_referential and not is_aliased_class:
                     to_selectable = to_selectable.alias()
 
                 single_crit = target_mapper._single_table_criterion
@@ -418,6 +418,7 @@ class RelationshipProperty(StrategizedProperty):
                     else:
                         criterion = single_crit
             else:
+                is_aliased_class = False
                 to_selectable = None
 
             if self.adapter:
@@ -445,8 +446,7 @@ class RelationshipProperty(StrategizedProperty):
             else:
                 j = _orm_annotate(pj, exclude=self.property.remote_side)
 
-            # MARKMARK
-            if criterion is not None and target_adapter:
+            if criterion is not None and target_adapter and not is_aliased_class:
                 # limit this adapter to annotated only?
                 criterion = target_adapter.traverse(criterion)
 
@@ -460,8 +460,10 @@ class RelationshipProperty(StrategizedProperty):
 
             crit = j & criterion
 
-            return sql.exists([1], crit, from_obj=dest).\
-                            correlate(source._annotate({'_orm_adapt':True}))
+            ex = sql.exists([1], crit, from_obj=dest).correlate_except(dest)
+            if secondary is not None:
+                ex = ex.correlate_except(secondary)
+            return ex
 
         def any(self, criterion=None, **kwargs):
             """Produce an expression that tests a collection against
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 2c06063bccfae83e8ca84f2643e0604efcf04eb8..987a77ba964c90c404298fb6592cd6d033ba1cdf 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -32,7 +32,7 @@ from sqlalchemy.orm import (
     )
 from sqlalchemy.orm.util import (
     AliasedClass, ORMAdapter, _entity_descriptor, _entity_info,
-    _extended_entity_info,
+    _extended_entity_info, PathRegistry,
     _is_aliased_class, _is_mapped_class, _orm_columns, _orm_selectable,
     join as orm_join,with_parent, _attr_as_key, aliased
     )
@@ -53,6 +53,8 @@ def _generative(*assertions):
         return self
     return generate
 
+_path_registry = PathRegistry.root
+
 class Query(object):
     """ORM-level SQL construction object.
 
@@ -88,7 +90,6 @@ class Query(object):
     _invoke_all_eagers = True
     _version_check = False
     _autoflush = True
-    _current_path = ()
     _only_load_props = None
     _refresh_state = None
     _from_obj = ()
@@ -105,6 +106,8 @@ class Query(object):
     _with_hints = ()
     _enable_single_crit = True
 
+    _current_path = _path_registry
+
     def __init__(self, entities, session=None):
         self.session = session
         self._polymorphic_adapters = {}
@@ -125,28 +128,25 @@ class Query(object):
         for ent in entities:
             for entity in ent.entities:
                 if entity not in d:
-                    mapper, selectable, \
-                    is_aliased_class, with_polymorphic_mappers, \
-                    with_polymorphic_discriminator = \
-                                            _extended_entity_info(entity)
-                    if not is_aliased_class and mapper.with_polymorphic:
-                        if mapper.mapped_table not in \
+                    ext_info = _extended_entity_info(entity)
+                    if not ext_info.is_aliased_class and ext_info.mapper.with_polymorphic:
+                        if ext_info.mapper.mapped_table not in \
                                             self._polymorphic_adapters:
-                            self._mapper_loads_polymorphically_with(mapper, 
+                            self._mapper_loads_polymorphically_with(ext_info.mapper, 
                                 sql_util.ColumnAdapter(
-                                            selectable, 
-                                            mapper._equivalent_columns))
+                                            ext_info.selectable, 
+                                            ext_info.mapper._equivalent_columns))
                         aliased_adapter = None
-                    elif is_aliased_class:
+                    elif ext_info.is_aliased_class:
                         aliased_adapter = sql_util.ColumnAdapter(
-                                            selectable, 
-                                            mapper._equivalent_columns)
+                                            ext_info.selectable, 
+                                            ext_info.mapper._equivalent_columns)
                     else:
                         aliased_adapter = None
 
-                    d[entity] = (mapper, aliased_adapter, selectable, 
-                                        is_aliased_class, with_polymorphic_mappers,
-                                        with_polymorphic_discriminator)
+                    d[entity] = (ext_info.mapper, aliased_adapter, ext_info.selectable, 
+                                        ext_info.is_aliased_class, ext_info.with_polymorphic_mappers,
+                                        ext_info.with_polymorphic_discriminator)
                 ent.setup_entity(entity, *d[entity])
 
     def _mapper_loads_polymorphically_with(self, mapper, adapter):
@@ -1251,7 +1251,7 @@ class Query(object):
     def having(self, criterion):
         """apply a HAVING criterion to the query and return the
         newly resulting :class:`.Query`.
-        
+
         :meth:`having` is used in conjunction with :meth:`group_by`.
  
         HAVING criterion makes it possible to use filters on aggregate
@@ -1940,7 +1940,6 @@ class Query(object):
             raise sa_exc.InvalidRequestError(
                     "Could not find a FROM clause to join from.  "
                     "Tried joining to %s, but got: %s" % (right, ae))
-
         self._from_obj = self._from_obj + (clause,)
 
     def _reset_joinpoint(self):
@@ -2872,7 +2871,7 @@ class _MapperEntity(_QueryEntity):
         query._entities.append(self)
 
         self.entities = [entity]
-        self.entity_zero = self.expr = entity
+        self.expr = entity
 
     def setup_entity(self, entity, mapper, aliased_adapter, 
                         from_obj, is_aliased_class, 
@@ -2885,16 +2884,12 @@ class _MapperEntity(_QueryEntity):
         self._with_polymorphic = with_polymorphic
         self._polymorphic_discriminator = with_polymorphic_discriminator
         if is_aliased_class:
-            self.path_entity = self.entity_zero = entity
-            self._path = (entity,)
+            self.entity_zero = entity
             self._label_name = self.entity_zero._sa_label_name
-            self._reduced_path = (self.path_entity, )
         else:
-            self.path_entity = mapper
-            self._path = (mapper,)
-            self._reduced_path = (mapper.base_mapper, )
             self.entity_zero = mapper
             self._label_name = self.mapper.class_.__name__
+        self.path = self.entity_zero._sa_path_registry
 
     def set_with_polymorphic(self, query, cls_or_mappers, 
                                 selectable, polymorphic_on):
@@ -2929,10 +2924,13 @@ class _MapperEntity(_QueryEntity):
         return self.entity_zero
 
     def corresponds_to(self, entity):
-        if _is_aliased_class(entity) or self.is_aliased_class:
-            return entity is self.path_entity
+        entity_info = _extended_entity_info(entity)
+        if entity_info.is_aliased_class or self.is_aliased_class:
+            return entity is self.entity_zero \
+                or \
+                entity in self._with_polymorphic
         else:
-            return entity.common_parent(self.path_entity)
+            return entity.common_parent(self.entity_zero)
 
     def adapt_to_selectable(self, query, sel):
         query._entities.append(self)
@@ -2976,8 +2974,7 @@ class _MapperEntity(_QueryEntity):
         if self.primary_entity:
             _instance = self.mapper._instance_processor(
                                 context, 
-                                self._path,
-                                self._reduced_path,
+                                self.path,
                                 adapter,
                                 only_load_props=query._only_load_props,
                                 refresh_state=context.refresh_state,
@@ -2987,8 +2984,7 @@ class _MapperEntity(_QueryEntity):
         else:
             _instance = self.mapper._instance_processor(
                                 context, 
-                                self._path,
-                                self._reduced_path,
+                                self.path,
                                 adapter,
                                 polymorphic_discriminator=
                                     self._polymorphic_discriminator)
@@ -3024,8 +3020,7 @@ class _MapperEntity(_QueryEntity):
             value.setup(
                 context,
                 self,
-                self._path,
-                self._reduced_path,
+                self.path,
                 adapter,
                 only_load_props=query._only_load_props,
                 column_collection=context.primary_columns
@@ -3211,7 +3206,8 @@ class QueryContext(object):
         self.create_eager_joins = []
         self.propagate_options = set(o for o in query._with_options if
                                         o.propagate_to_loaders)
-        self.attributes = query._attributes.copy()
+        self.attributes = self._attributes = query._attributes.copy()
+
 
 class AliasOption(interfaces.MapperOption):
 
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
index 9b0f7538f155f2c745fd23141578234890bf022e..720554483b5ee5cea3488a9c520e36c44672df80 100644
--- a/lib/sqlalchemy/orm/state.py
+++ b/lib/sqlalchemy/orm/state.py
@@ -24,6 +24,7 @@ from sqlalchemy.orm.attributes import PASSIVE_NO_RESULT, \
 mapperlib = util.importlater("sqlalchemy.orm", "mapperlib")
 sessionlib = util.importlater("sqlalchemy.orm", "session")
 
+
 class InstanceState(object):
     """tracks state information at the instance level."""
 
@@ -177,7 +178,7 @@ class InstanceState(object):
             ) if k in self.__dict__ 
         )
         if self.load_path:
-            d['load_path'] = interfaces.serialize_path(self.load_path)
+            d['load_path'] = self.load_path.serialize()
 
         self.manager.dispatch.pickle(self, d)
 
@@ -222,7 +223,8 @@ class InstanceState(object):
         ])
 
         if 'load_path' in state:
-            self.load_path = interfaces.deserialize_path(state['load_path'])
+            self.load_path = orm_util.PathRegistry.\
+                                deserialize(state['load_path'])
 
         # setup _sa_instance_state ahead of time so that 
         # unpickle events can access the object normally.
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index d0f8962be7923fc1f470538c1f9790edbc49ab84..131ced0c9dd2a5ed3f8f7b54110aaaa6e8351082 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -73,7 +73,8 @@ def _register_attribute(strategy, mapper, useobject,
                 compare_function=compare_function, 
                 useobject=useobject,
                 extension=attribute_ext, 
-                trackparent=useobject and (prop.single_parent or prop.direction is interfaces.ONETOMANY), 
+                trackparent=useobject and (prop.single_parent 
+                                or prop.direction is interfaces.ONETOMANY), 
                 typecallable=typecallable,
                 callable_=callable_, 
                 active_history=active_history,
@@ -92,27 +93,29 @@ class UninstrumentedColumnLoader(LoaderStrategy):
     if the argument is against the with_polymorphic selectable.
 
     """
-    def init(self):
+    def __init__(self, parent):
+        super(UninstrumentedColumnLoader, self).__init__(parent)
         self.columns = self.parent_property.columns
 
-    def setup_query(self, context, entity, path, reduced_path, adapter, 
+    def setup_query(self, context, entity, path, adapter, 
                             column_collection=None, **kwargs):
         for c in self.columns:
             if adapter:
                 c = adapter.columns[c]
             column_collection.append(c)
 
-    def create_row_processor(self, context, path, reduced_path, mapper, row, adapter):
+    def create_row_processor(self, context, path, mapper, row, adapter):
         return None, None, None
 
 class ColumnLoader(LoaderStrategy):
     """Provide loading behavior for a :class:`.ColumnProperty`."""
 
-    def init(self):
+    def __init__(self, parent):
+        super(ColumnLoader, self).__init__(parent)
         self.columns = self.parent_property.columns
         self.is_composite = hasattr(self.parent_property, 'composite_class')
 
-    def setup_query(self, context, entity, path, reduced_path, 
+    def setup_query(self, context, entity, path, 
                             adapter, column_collection, **kwargs):
         for c in self.columns:
             if adapter:
@@ -131,7 +134,7 @@ class ColumnLoader(LoaderStrategy):
             active_history = active_history
        )
 
-    def create_row_processor(self, context, path, reduced_path, 
+    def create_row_processor(self, context, path, 
                                             mapper, row, adapter):
         key = self.key
         # look through list of columns represented here
@@ -153,7 +156,15 @@ log.class_logger(ColumnLoader)
 class DeferredColumnLoader(LoaderStrategy):
     """Provide loading behavior for a deferred :class:`.ColumnProperty`."""
 
-    def create_row_processor(self, context, path, reduced_path, mapper, row, adapter):
+    def __init__(self, parent):
+        super(DeferredColumnLoader, self).__init__(parent)
+        if hasattr(self.parent_property, 'composite_class'):
+            raise NotImplementedError("Deferred loading for composite "
+                                    "types not implemented yet")
+        self.columns = self.parent_property.columns
+        self.group = self.parent_property.group
+
+    def create_row_processor(self, context, path, mapper, row, adapter):
         col = self.columns[0]
         if adapter:
             col = adapter.columns[col]
@@ -162,7 +173,7 @@ class DeferredColumnLoader(LoaderStrategy):
         if col in row:
             return self.parent_property._get_strategy(ColumnLoader).\
                         create_row_processor(
-                                context, path, reduced_path, mapper, row, adapter)
+                                context, path, mapper, row, adapter)
 
         elif not self.is_class_level:
             def set_deferred_for_local_state(state, dict_, row):
@@ -175,13 +186,6 @@ class DeferredColumnLoader(LoaderStrategy):
                 state.reset(dict_, key)
             return reset_col_for_deferred, None, None
 
-    def init(self):
-        if hasattr(self.parent_property, 'composite_class'):
-            raise NotImplementedError("Deferred loading for composite "
-                                    "types not implemented yet")
-        self.columns = self.parent_property.columns
-        self.group = self.parent_property.group
-
     def init_class_attribute(self, mapper):
         self.is_class_level = True
 
@@ -191,7 +195,7 @@ class DeferredColumnLoader(LoaderStrategy):
              expire_missing=False
         )
 
-    def setup_query(self, context, entity, path, reduced_path, adapter, 
+    def setup_query(self, context, entity, path, adapter, 
                                 only_load_props=None, **kwargs):
         if (
                 self.group is not None and 
@@ -199,7 +203,7 @@ class DeferredColumnLoader(LoaderStrategy):
             ) or (only_load_props and self.key in only_load_props):
             self.parent_property._get_strategy(ColumnLoader).\
                             setup_query(context, entity,
-                                        path, reduced_path, adapter, **kwargs)
+                                        path, adapter, **kwargs)
 
     def _load_for_state(self, state, passive):
         if not state.key:
@@ -276,12 +280,13 @@ class UndeferGroupOption(MapperOption):
         self.group = group
 
     def process_query(self, query):
-        query._attributes[('undefer', self.group)] = True
+        query._attributes[("undefer", self.group)] = True
 
 class AbstractRelationshipLoader(LoaderStrategy):
     """LoaderStratgies which deal with related objects."""
 
-    def init(self):
+    def __init__(self, parent):
+        super(AbstractRelationshipLoader, self).__init__(parent)
         self.mapper = self.parent_property.mapper
         self.target = self.parent_property.target
         self.uselist = self.parent_property.uselist
@@ -301,7 +306,7 @@ class NoLoader(AbstractRelationshipLoader):
             typecallable = self.parent_property.collection_class,
         )
 
-    def create_row_processor(self, context, path, reduced_path, mapper, row, adapter):
+    def create_row_processor(self, context, path, mapper, row, adapter):
         def invoke_no_load(state, dict_, row):
             state.initialize(self.key)
         return invoke_no_load, None, None
@@ -314,8 +319,8 @@ class LazyLoader(AbstractRelationshipLoader):
     
     """
 
-    def init(self):
-        super(LazyLoader, self).init()
+    def __init__(self, parent):
+        super(LazyLoader, self).__init__(parent)
         join_condition = self.parent_property._join_condition
         self._lazywhere, \
         self._bind_to_col, \
@@ -533,7 +538,7 @@ class LazyLoader(AbstractRelationshipLoader):
             q = q.autoflush(False)
 
         if state.load_path:
-            q = q._with_current_path(state.load_path + (self.key,))
+            q = q._with_current_path(state.load_path[self.key])
 
         if state.load_options:
             q = q._conditional_options(*state.load_options)
@@ -578,7 +583,7 @@ class LazyLoader(AbstractRelationshipLoader):
                 return None
 
 
-    def create_row_processor(self, context, path, reduced_path, 
+    def create_row_processor(self, context, path, 
                                     mapper, row, adapter):
         key = self.key
         if not self.is_class_level:
@@ -633,11 +638,11 @@ class ImmediateLoader(AbstractRelationshipLoader):
                 init_class_attribute(mapper)
 
     def setup_query(self, context, entity, 
-                        path, reduced_path, adapter, column_collection=None,
+                        path, adapter, column_collection=None,
                         parentmapper=None, **kwargs):
         pass
 
-    def create_row_processor(self, context, path, reduced_path, 
+    def create_row_processor(self, context, path, 
                                 mapper, row, adapter):
         def load_immediate(state, dict_, row):
             state.get_impl(self.key).get(state, dict_)
@@ -645,8 +650,8 @@ class ImmediateLoader(AbstractRelationshipLoader):
         return None, None, load_immediate
 
 class SubqueryLoader(AbstractRelationshipLoader):
-    def init(self):
-        super(SubqueryLoader, self).init()
+    def __init__(self, parent):
+        super(SubqueryLoader, self).__init__(parent)
         self.join_depth = self.parent_property.join_depth
 
     def init_class_attribute(self, mapper):
@@ -655,31 +660,36 @@ class SubqueryLoader(AbstractRelationshipLoader):
                 init_class_attribute(mapper)
 
     def setup_query(self, context, entity, 
-                        path, reduced_path, adapter, 
+                        path, adapter, 
                         column_collection=None,
                         parentmapper=None, **kwargs):
 
         if not context.query._enable_eagerloads:
             return
 
-        path = path + (self.key, )
-        reduced_path = reduced_path + (self.key, )
+        path = path[self.key]
 
         # build up a path indicating the path from the leftmost
         # entity to the thing we're subquery loading.
-        subq_path = context.attributes.get(('subquery_path', None), ())
+        with_poly_info = path.get(context, "path_with_polymorphic", None)
+        if with_poly_info is not None:
+            effective_entity = with_poly_info.entity
+        else:
+            effective_entity = self.mapper
+
+        subq_path = context.attributes.get(('subquery_path', None), 
+                                mapperutil.PathRegistry.root)
 
         subq_path = subq_path + path
 
-        # join-depth / recursion check
-        if ("loaderstrategy", reduced_path) not in context.attributes:
+        # if not via query option, check for 
+        # a cycle
+        if not path.contains(context, "loaderstrategy"):
             if self.join_depth:
-                if len(path) / 2 > self.join_depth:
-                    return
-            else:
-                if self.mapper.base_mapper in \
-                    interfaces._reduce_path(subq_path):
+                if path.length / 2 > self.join_depth:
                     return
+            elif subq_path.contains_mapper(self.mapper):
+                return
 
         subq_mapper, leftmost_mapper, leftmost_attr = \
                 self._get_leftmost(subq_path)
@@ -692,7 +702,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
         # produce a subquery from it.
         left_alias = self._generate_from_original_query(
                             orig_query, leftmost_mapper,
-                            leftmost_attr, subq_path
+                            leftmost_attr
         )
 
         # generate another Query that will join the 
@@ -700,7 +710,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
         # basically doing a longhand
         # "from_self()".  (from_self() itself not quite industrial
         # strength enough for all contingencies...but very close)
-        q = orig_query.session.query(self.mapper)
+        q = orig_query.session.query(effective_entity)
         q._attributes = {
             ("orig_query", SubqueryLoader): orig_query,
             ('subquery_path', None) : subq_path
@@ -712,16 +722,18 @@ class SubqueryLoader(AbstractRelationshipLoader):
         q = q.order_by(*local_attr)
         q = q.add_columns(*local_attr)
 
-        q = self._apply_joins(q, to_join, left_alias, parent_alias)
+        q = self._apply_joins(q, to_join, left_alias, 
+                            parent_alias, effective_entity)
 
-        q = self._setup_options(q, subq_path, orig_query)
+        q = self._setup_options(q, subq_path, orig_query, effective_entity)
         q = self._setup_outermost_orderby(q)
 
         # add new query to attributes to be picked up 
         # by create_row_processor
-        context.attributes[('subquery', reduced_path)] = q
+        path.set(context, "subquery", q)
 
     def _get_leftmost(self, subq_path):
+        subq_path = subq_path.path
         subq_mapper = mapperutil._class_to_mapper(subq_path[0])
 
         # determine attributes of the leftmost mapper
@@ -743,7 +755,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
 
     def _generate_from_original_query(self,
             orig_query, leftmost_mapper,
-            leftmost_attr, subq_path
+            leftmost_attr
     ):
         # reformat the original query
         # to look only for significant columns
@@ -769,6 +781,8 @@ class SubqueryLoader(AbstractRelationshipLoader):
 
 
     def _prep_for_joins(self, left_alias, subq_path):
+        subq_path = subq_path.path
+
         # figure out what's being joined.  a.k.a. the fun part
         to_join = [
                     (subq_path[i], subq_path[i+1]) 
@@ -778,11 +792,14 @@ class SubqueryLoader(AbstractRelationshipLoader):
         # determine the immediate parent class we are joining from,
         # which needs to be aliased.
 
+        if len(to_join) > 1:
+            ext = mapperutil._extended_entity_info(subq_path[-2])
+
         if len(to_join) < 2:
             # in the case of a one level eager load, this is the
             # leftmost "left_alias".
             parent_alias = left_alias
-        elif subq_path[-2].isa(self.parent):
+        elif ext.mapper.isa(self.parent):
             # In the case of multiple levels, retrieve
             # it from subq_path[-2]. This is the same as self.parent 
             # in the vast majority of cases, and [ticket:2014] 
@@ -800,10 +817,10 @@ class SubqueryLoader(AbstractRelationshipLoader):
             getattr(parent_alias, self.parent._columntoproperty[c].key)
             for c in local_cols
         ]
-
         return to_join, local_attr, parent_alias
 
-    def _apply_joins(self, q, to_join, left_alias, parent_alias):
+    def _apply_joins(self, q, to_join, left_alias, parent_alias, 
+                    effective_entity):
         for i, (mapper, key) in enumerate(to_join):
 
             # we need to use query.join() as opposed to
@@ -816,11 +833,18 @@ class SubqueryLoader(AbstractRelationshipLoader):
             first = i == 0
             middle = i < len(to_join) - 1
             second_to_last = i == len(to_join) - 2
+            last = i == len(to_join) - 1
 
             if first:
                 attr = getattr(left_alias, key)
+                if last and effective_entity is not self.mapper:
+                    attr = attr.of_type(effective_entity)
             else:
-                attr = key
+                if last and effective_entity is not self.mapper:
+                    attr = getattr(parent_alias, key).\
+                                    of_type(effective_entity)
+                else:
+                    attr = key
 
             if second_to_last:
                 q = q.join(parent_alias, attr, from_joinpoint=True)
@@ -828,13 +852,14 @@ class SubqueryLoader(AbstractRelationshipLoader):
                 q = q.join(attr, aliased=middle, from_joinpoint=True)
         return q
 
-    def _setup_options(self, q, subq_path, orig_query):
+    def _setup_options(self, q, subq_path, orig_query, effective_entity):
         # propagate loader options etc. to the new query.
         # these will fire relative to subq_path.
         q = q._with_current_path(subq_path)
         q = q._conditional_options(*orig_query._with_options)
         if orig_query._populate_existing: 
             q._populate_existing = orig_query._populate_existing
+
         return q
 
     def _setup_outermost_orderby(self, q):
@@ -855,7 +880,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
             q = q.order_by(*eager_order_by)
         return q
 
-    def create_row_processor(self, context, path, reduced_path, 
+    def create_row_processor(self, context, path,
                                     mapper, row, adapter):
         if not self.parent.class_manager[self.key].impl.supports_population:
             raise sa_exc.InvalidRequestError(
@@ -863,27 +888,26 @@ class SubqueryLoader(AbstractRelationshipLoader):
                         "population - eager loading cannot be applied." % 
                         self)
 
-        reduced_path = reduced_path + (self.key,)
+        path = path[self.key]
 
-        if ('subquery', reduced_path) not in context.attributes:
+        subq = path.get(context, 'subquery')
+        if subq is None:
             return None, None, None
 
         local_cols = self.parent_property.local_columns
 
-        q = context.attributes[('subquery', reduced_path)]
-
         # cache the loaded collections in the context
         # so that inheriting mappers don't re-load when they
         # call upon create_row_processor again
-        if ('collections', reduced_path) in context.attributes:
-            collections = context.attributes[('collections', reduced_path)]
-        else:
-            collections = context.attributes[('collections', reduced_path)] = dict(
+        collections = path.get(context, "collections")
+        if collections is None:
+            collections = dict(
                     (k, [v[0] for v in v]) 
                     for k, v in itertools.groupby(
-                        q, 
+                        subq, 
                         lambda x:x[1:]
                     ))
+            path.set(context, 'collections', collections)
 
         if adapter:
             local_cols = [adapter.columns[c] for c in local_cols]
@@ -929,98 +953,114 @@ class JoinedLoader(AbstractRelationshipLoader):
     using joined eager loading.
     
     """
-    def init(self):
-        super(JoinedLoader, self).init()
+    def __init__(self, parent):
+        super(JoinedLoader, self).__init__(parent)
         self.join_depth = self.parent_property.join_depth
 
     def init_class_attribute(self, mapper):
         self.parent_property.\
             _get_strategy(LazyLoader).init_class_attribute(mapper)
 
-    def setup_query(self, context, entity, path, reduced_path, adapter, \
+    def setup_query(self, context, entity, path, adapter, \
                                 column_collection=None, parentmapper=None,
                                 allow_innerjoin=True,
                                 **kwargs):
         """Add a left outer join to the statement thats being constructed."""
 
-
         if not context.query._enable_eagerloads:
             return
 
-        path = path + (self.key,)
-        reduced_path = reduced_path + (self.key,)
+        path = path[self.key]
+
+        with_polymorphic = None
 
-        if ("user_defined_eager_row_processor", reduced_path) in\
-                context.attributes:
+        user_defined_adapter = path.get(context, 
+                                "user_defined_eager_row_processor", 
+                                False)
+        if user_defined_adapter is not False:
             clauses, adapter, add_to_collection = \
                 self._get_user_defined_adapter(
-                    context, entity, reduced_path, adapter
+                    context, entity, path, adapter,
+                    user_defined_adapter
                 )
         else:
-            # check for join_depth or basic recursion,
-            # if the current path was not explicitly stated as 
-            # a desired "loaderstrategy" (i.e. via query.options())
-            if ("loaderstrategy", reduced_path) not in context.attributes:
+            # if not via query option, check for 
+            # a cycle
+            if not path.contains(context, "loaderstrategy"):
                 if self.join_depth:
-                    if len(path) / 2 > self.join_depth:
-                        return
-                else:
-                    if self.mapper.base_mapper in reduced_path:
+                    if path.length / 2 > self.join_depth:
                         return
+                elif path.contains_mapper(self.mapper):
+                    return
 
             clauses, adapter, add_to_collection, \
                 allow_innerjoin = self._generate_row_adapter(
-                    context, entity, path, reduced_path, adapter,
+                    context, entity, path, adapter,
                     column_collection, parentmapper, allow_innerjoin
                 )
 
-        path += (self.mapper,)
-        reduced_path += (self.mapper.base_mapper,)
+        with_poly_info = path.get(
+            context, 
+            "path_with_polymorphic",
+            None
+        )
+        if with_poly_info is not None:
+            with_polymorphic = with_poly_info.with_polymorphic_mappers
+        else:
+            with_polymorphic = None
 
-        for value in self.mapper._polymorphic_properties:
+        path = path[self.mapper]
+        for value in self.mapper._iterate_polymorphic_properties(
+                                mappers=with_polymorphic):
             value.setup(
                 context, 
                 entity, 
                 path, 
-                reduced_path,
                 clauses, 
                 parentmapper=self.mapper, 
                 column_collection=add_to_collection,
                 allow_innerjoin=allow_innerjoin)
 
     def _get_user_defined_adapter(self, context, entity, 
-                                reduced_path, adapter):
-            clauses = context.attributes[
-                                ("user_defined_eager_row_processor",
-                                reduced_path)]
+                                path, adapter, user_defined_adapter):
 
             adapter = entity._get_entity_clauses(context.query, context)
-            if adapter and clauses:
-                context.attributes[
-                            ("user_defined_eager_row_processor",
-                            reduced_path)] = clauses = clauses.wrap(adapter)
+            if adapter and user_defined_adapter:
+                user_defined_adapter = user_defined_adapter.wrap(adapter)
+                path.set(context, "user_defined_eager_row_processor", 
+                                        user_defined_adapter)
             elif adapter:
-                context.attributes[
-                            ("user_defined_eager_row_processor",
-                            reduced_path)] = clauses = adapter
+                user_defined_adapter = adapter
+                path.set(context, "user_defined_eager_row_processor", 
+                                        user_defined_adapter)
 
             add_to_collection = context.primary_columns
-            return clauses, adapter, add_to_collection
+            return user_defined_adapter, adapter, add_to_collection
 
     def _generate_row_adapter(self, 
-        context, entity, path, reduced_path, adapter,
+        context, entity, path, adapter,
         column_collection, parentmapper, allow_innerjoin
     ):
+        with_poly_info = path.get(
+            context, 
+            "path_with_polymorphic", 
+            None
+        )
+        if with_poly_info:
+            to_adapt = with_poly_info.entity
+        else:
+            to_adapt = mapperutil.AliasedClass(self.mapper)
         clauses = mapperutil.ORMAdapter(
-                    mapperutil.AliasedClass(self.mapper),
+                    to_adapt,
                     equivalents=self.mapper._equivalent_columns,
                     adapt_required=True)
+        assert clauses.aliased_class is not None
 
         if self.parent_property.direction != interfaces.MANYTOONE:
             context.multi_row_eager_loaders = True
 
-        innerjoin = allow_innerjoin and context.attributes.get(
-                            ("eager_join_type", path)
+        innerjoin = allow_innerjoin and path.get(context,
+                            "eager_join_type"
                             self.parent_property.innerjoin)
         if not innerjoin:
             # if this is an outer join, all eager joins from
@@ -1034,9 +1074,7 @@ class JoinedLoader(AbstractRelationshipLoader):
         )
 
         add_to_collection = context.secondary_columns
-        context.attributes[
-                            ("eager_row_processor", reduced_path)
-                          ] = clauses
+        path.set(context, "eager_row_processor", clauses)
         return clauses, adapter, add_to_collection, allow_innerjoin
 
     def _create_eager_join(self, context, entity, 
@@ -1055,6 +1093,7 @@ class JoinedLoader(AbstractRelationshipLoader):
             context.query._should_nest_selectable
 
         entity_key = None
+
         if entity not in context.eager_joins and \
             not should_nest_selectable and \
             context.from_clause:
@@ -1096,6 +1135,7 @@ class JoinedLoader(AbstractRelationshipLoader):
         else:
             onclause = self.parent_property
 
+        assert clauses.aliased_class is not None
         context.eager_joins[entity_key] = eagerjoin = \
                                 mapperutil.join(
                                             towrap, 
@@ -1134,12 +1174,12 @@ class JoinedLoader(AbstractRelationshipLoader):
                                 )
 
 
-    def _create_eager_adapter(self, context, row, adapter, path, reduced_path):
-        if ("user_defined_eager_row_processor", reduced_path) in \
-                                                    context.attributes:
-            decorator = context.attributes[
-                            ("user_defined_eager_row_processor",
-                            reduced_path)]
+    def _create_eager_adapter(self, context, row, adapter, path):
+        user_defined_adapter = path.get(context, 
+                                "user_defined_eager_row_processor", 
+                                False)
+        if user_defined_adapter is not False:
+            decorator = user_defined_adapter
             # user defined eagerloads are part of the "primary" 
             # portion of the load.
             # the adapters applied to the Query should be honored.
@@ -1147,11 +1187,10 @@ class JoinedLoader(AbstractRelationshipLoader):
                 decorator = decorator.wrap(context.adapter)
             elif context.adapter:
                 decorator = context.adapter
-        elif ("eager_row_processor", reduced_path) in context.attributes:
-            decorator = context.attributes[
-                            ("eager_row_processor", reduced_path)]
         else:
-            return False
+            decorator = path.get(context, "eager_row_processor")
+            if decorator is None:
+                return False
 
         try:
             self.mapper.identity_key_from_row(row, decorator)
@@ -1161,28 +1200,26 @@ class JoinedLoader(AbstractRelationshipLoader):
             # processor, will cause a degrade to lazy
             return False
 
-    def create_row_processor(self, context, path, reduced_path, mapper, row, adapter):
+    def create_row_processor(self, context, path, mapper, row, adapter):
         if not self.parent.class_manager[self.key].impl.supports_population:
             raise sa_exc.InvalidRequestError(
                         "'%s' does not support object "
                         "population - eager loading cannot be applied." % 
                         self)
 
-        our_path = path + (self.key,)
-        our_reduced_path = reduced_path + (self.key,)
+        our_path = path[self.key]
 
         eager_adapter = self._create_eager_adapter(
                                                 context, 
                                                 row, 
-                                                adapter, our_path,
-                                                our_reduced_path)
+                                                adapter, our_path)
 
         if eager_adapter is not False:
             key = self.key
+
             _instance = self.mapper._instance_processor(
                                 context, 
-                                our_path + (self.mapper,), 
-                                our_reduced_path + (self.mapper.base_mapper,),
+                                our_path[self.mapper],
                                 eager_adapter)
 
             if not self.uselist:
@@ -1193,8 +1230,7 @@ class JoinedLoader(AbstractRelationshipLoader):
             return self.parent_property.\
                             _get_strategy(LazyLoader).\
                             create_row_processor(
-                                            context, path, 
-                                            reduced_path,
+                                            context, path,
                                             mapper, row, adapter)
 
     def _create_collection_loader(self, context, key, _instance):
@@ -1279,19 +1315,18 @@ class EagerLazyOption(StrategizedOption):
     def get_strategy_class(self):
         return self.strategy_cls
 
+_factory = {
+    False:JoinedLoader,
+    "joined":JoinedLoader,
+    None:NoLoader,
+    "noload":NoLoader,
+    "select":LazyLoader,
+    True:LazyLoader,
+    "subquery":SubqueryLoader,
+    "immediate":ImmediateLoader
+}
 def factory(identifier):
-    if identifier is False or identifier == 'joined':
-        return JoinedLoader
-    elif identifier is None or identifier == 'noload':
-        return NoLoader
-    elif identifier is False or identifier == 'select':
-        return LazyLoader
-    elif identifier == 'subquery':
-        return SubqueryLoader
-    elif identifier == 'immediate':
-        return ImmediateLoader
-    else:
-        return LazyLoader
+    return _factory.get(identifier, LazyLoader)
 
 class EagerJoinOption(PropertyOption):
 
@@ -1300,12 +1335,12 @@ class EagerJoinOption(PropertyOption):
         self.innerjoin = innerjoin
         self.chained = chained
 
-    def process_query_property(self, query, paths, mappers):
+    def process_query_property(self, query, paths):
         if self.chained:
             for path in paths:
-                query._attributes[("eager_join_type", path)] = self.innerjoin
+                path.set(query, "eager_join_type", self.innerjoin)
         else:
-            query._attributes[("eager_join_type", paths[-1])] = self.innerjoin
+            paths[-1].set(query, "eager_join_type", self.innerjoin)
 
 class LoadEagerFromAliasOption(PropertyOption):
 
@@ -1313,36 +1348,41 @@ class LoadEagerFromAliasOption(PropertyOption):
         super(LoadEagerFromAliasOption, self).__init__(key)
         if alias is not None:
             if not isinstance(alias, basestring):
-                m, alias, is_aliased_class = mapperutil._entity_info(alias)
+                mapper, alias, is_aliased_class = \
+                        mapperutil._entity_info(alias)
         self.alias = alias
         self.chained = chained
 
-    def process_query_property(self, query, paths, mappers):
+    def process_query_property(self, query, paths):
         if self.chained:
             for path in paths[0:-1]:
-                (root_mapper, propname) = path[-2:]
+                (root_mapper, propname) = path.path[-2:]
                 prop = root_mapper._props[propname]
                 adapter = query._polymorphic_adapters.get(prop.mapper, None)
-                query._attributes.setdefault(
-                            ("user_defined_eager_row_processor", 
-                            interfaces._reduce_path(path)), adapter)
+                path.setdefault(query, 
+                            "user_defined_eager_row_processor", 
+                            adapter)
 
+        root_mapper, propname = paths[-1].path[-2:]
+        prop = root_mapper._props[propname]
         if self.alias is not None:
             if isinstance(self.alias, basestring):
-                (root_mapper, propname) = paths[-1][-2:]
-                prop = root_mapper._props[propname]
                 self.alias = prop.target.alias(self.alias)
-            query._attributes[
-                        ("user_defined_eager_row_processor"
-                        interfaces._reduce_path(paths[-1]))
-                        ] = sql_util.ColumnAdapter(self.alias)
+            paths[-1].set(query, "user_defined_eager_row_processor", 
+                sql_util.ColumnAdapter(self.alias
+                                equivalents=prop.mapper._equivalent_columns)
+            )
         else:
-            (root_mapper, propname) = paths[-1][-2:]
-            prop = root_mapper._props[propname]
-            adapter = query._polymorphic_adapters.get(prop.mapper, None)
-            query._attributes[
-                        ("user_defined_eager_row_processor", 
-                        interfaces._reduce_path(paths[-1]))] = adapter
+            if paths[-1].contains(query, "path_with_polymorphic"):
+                with_poly_info = paths[-1].get(query, "path_with_polymorphic")
+                adapter = mapperutil.ORMAdapter(
+                            with_poly_info.entity, 
+                            equivalents=prop.mapper._equivalent_columns,
+                            adapt_required=True)
+            else:
+                adapter = query._polymorphic_adapters.get(prop.mapper, None)
+            paths[-1].set(query, "user_defined_eager_row_processor", 
+                                    adapter)
 
 def single_parent_validator(desc, prop):
     def _do_check(state, value, oldvalue, initiator):
@@ -1363,6 +1403,8 @@ def single_parent_validator(desc, prop):
     def set_(state, value, oldvalue, initiator):
         return _do_check(state, value, oldvalue, initiator)
 
-    event.listen(desc, 'append', append, raw=True, retval=True, active_history=True)
-    event.listen(desc, 'set', set_, raw=True, retval=True, active_history=True)
+    event.listen(desc, 'append', append, raw=True, retval=True, 
+                            active_history=True)
+    event.listen(desc, 'set', set_, raw=True, retval=True, 
+                            active_history=True)
 
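
The net effect of the strategy changes above is that subqueryload() and
joinedload() now honor of_type().  A rough usage sketch, assuming the
A/ASub/B mapping that appears in the memory tests further down (names are
illustrative only)::

    from sqlalchemy.orm import Session, subqueryload, joinedload

    sess = Session()

    # subquery eager load of B.as_, restricting the eager load target
    # to the ASub subtype via of_type()
    sess.query(B).options(subqueryload(B.as_.of_type(ASub))).all()

    # joined eager loading accepts the same construct; the joined
    # alias is derived from the of_type() entity rather than self.mapper
    sess.query(B).options(joinedload(B.as_.of_type(ASub))).all()
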
index 51aaa31524a9fdc964fba2d3c328e887ccadbbc0..0978ab693d61110eb20162ccd61ac3bbb2c21727 100644 (file)
@@ -9,6 +9,7 @@ from sqlalchemy import sql, util, event, exc as sa_exc, inspection
 from sqlalchemy.sql import expression, util as sql_util, operators
 from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE,\
                                 PropComparator, MapperProperty
+from itertools import chain
 from sqlalchemy.orm import attributes, exc
 import operator
 import re
@@ -233,6 +234,144 @@ class ORMAdapter(sql_util.ColumnAdapter):
         else:
             return None
 
+class PathRegistry(object):
+    """Represent query load paths and registry functions.
+
+    Basically represents structures like:
+
+    (<User mapper>, "orders", <Order mapper>, "items", <Item mapper>)
+
+    These structures are generated by things like
+    query options (joinedload(), subqueryload(), etc.) and are
+    used to compose keys stored in the query._attributes dictionary
+    for various options.
+
+    They are then re-composed at query compile/result row time as
+    the query is formed and as rows are fetched, where they again
+    serve to compose keys to look up options in the context.attributes
+    dictionary, which is copied from query._attributes.
+
+    The path structure has a limited amount of caching, where each
+    "root" ultimately pulls from a fixed registry associated with
+    the first mapper, which also contains elements for each of its
+    property keys.  However, paths longer than two elements, which
+    are the exception rather than the rule, are generated on an 
+    as-needed basis.
+
+    """
+
+    def __eq__(self, other):
+        return other is not None and \
+            self.path == other.path
+
+    def set(self, reg, key, value):
+        reg._attributes[(key, self.reduced_path)] = value
+
+    def setdefault(self, reg, key, value):
+        reg._attributes.setdefault((key, self.reduced_path), value)
+
+    def get(self, reg, key, value=None):
+        key = (key, self.reduced_path)
+        if key in reg._attributes:
+            return reg._attributes[key]
+        else:
+            return value
+
+    @property
+    def length(self):
+        return len(self.path)
+
+    def contains_mapper(self, mapper):
+        return mapper.base_mapper in self.reduced_path
+
+    def contains(self, reg, key):
+        return (key, self.reduced_path) in reg._attributes
+
+    def serialize(self):
+        path = self.path
+        return zip(
+            [m.class_ for m in [path[i] for i in range(0, len(path), 2)]], 
+            [path[i] for i in range(1, len(path), 2)] + [None]
+        )
+
+    @classmethod
+    def deserialize(cls, path):
+        if path is None:
+            return None
+
+        p = tuple(chain(*[(class_mapper(mcls), key) for mcls, key in path]))
+        if p and p[-1] is None:
+            p = p[0:-1]
+        return cls.coerce(p)
+
+    @classmethod
+    def per_mapper(cls, mapper):
+        return EntityRegistry(
+                cls.root, mapper
+            )
+
+    @classmethod
+    def coerce(cls, raw):
+        return util.reduce(lambda prev, next:prev[next], raw, cls.root)
+
+    @classmethod
+    def token(cls, token):
+        return KeyRegistry(cls.root, token)
+
+    def __add__(self, other):
+        return util.reduce(
+                    lambda prev, next:prev[next],
+                    other.path, self)
+
+    def __repr__(self):
+        return "%s(%r)" % (self.__class__.__name__, self.path, )
+
+class RootRegistry(PathRegistry):
+    """Root registry, defers to mappers so that
+    paths are maintained per-root-mapper.
+
+    """
+    path = ()
+    reduced_path = ()
+
+    def __getitem__(self, mapper):
+        return mapper._sa_path_registry
+PathRegistry.root = RootRegistry()
+
+class KeyRegistry(PathRegistry):
+    def __init__(self, parent, key):
+        self.key = key
+        self.parent = parent
+        self.path = parent.path + (key,)
+        self.reduced_path = parent.reduced_path + (key,)
+
+    def __getitem__(self, entity):
+        return EntityRegistry(
+            self, entity
+        )
+
+class EntityRegistry(PathRegistry, dict):
+    is_aliased_class = False
+
+    def __init__(self, parent, entity):
+        self.key = reduced_key = entity
+        self.parent = parent
+        if hasattr(entity, 'base_mapper'):
+            reduced_key = entity.base_mapper
+        else:
+            self.is_aliased_class = True
+
+        self.path = parent.path + (entity,)
+        self.reduced_path = parent.reduced_path + (reduced_key,)
+
+    def __nonzero__(self):
+        return True
+
+    def __missing__(self, key):
+        self[key] = item = KeyRegistry(self, key)
+        return item
+
+
 class AliasedClass(object):
     """Represents an "aliased" form of a mapped class for usage with Query.
 
@@ -321,6 +460,10 @@ class AliasedClass(object):
         self._sa_label_name = name
         self.__name__ = 'AliasedClass_' + str(self.__target)
 
+    @util.memoized_property
+    def _sa_path_registry(self):
+        return PathRegistry.per_mapper(self)
+
     def __getstate__(self):
         return {
             'mapper':self.__mapper, 
@@ -408,7 +551,8 @@ def aliased(element, alias=None, name=None, adapt_on_names=False):
                     name=name, adapt_on_names=adapt_on_names)
 
 def with_polymorphic(base, classes, selectable=False, 
-                        polymorphic_on=None, aliased=False):
+                        polymorphic_on=None, aliased=False,
+                        innerjoin=False):
     """Produce an :class:`.AliasedClass` construct which specifies
     columns for descendant mappers of the given base.
 
@@ -422,23 +566,23 @@ def with_polymorphic(base, classes, selectable=False,
     criterion to be used against those tables.  The resulting
     instances will also have those columns already loaded so that
     no "post fetch" of those columns will be required.
-    
+
     See the examples at :ref:`with_polymorphic`.
 
     :param base: Base class to be aliased.
-    
+
     :param cls_or_mappers: a single class or mapper, or list of
         class/mappers, which inherit from the base class.
         Alternatively, it may also be the string ``'*'``, in which case
         all descending mapped classes will be added to the FROM clause.
-    
+
     :param aliased: when True, the selectable will be wrapped in an
         alias, that is ``(SELECT * FROM <fromclauses>) AS anon_1``.
         This can be important when using the with_polymorphic()
         to create the target of a JOIN on a backend that does not
         support parenthesized joins, such as SQLite and older
         versions of MySQL.
-        
+
     :param selectable: a table or select() statement that will
         be used in place of the generated FROM clause. This argument is
         required if any of the desired classes use concrete table
@@ -455,10 +599,12 @@ def with_polymorphic(base, classes, selectable=False,
         is useful for mappings that don't have polymorphic loading 
         behavior by default.
 
+    :param innerjoin: if True, an INNER JOIN will be used.  This should
+       only be specified when querying for one specific subtype.
     """
-    primary_mapper = class_mapper(base)
+    primary_mapper = _class_to_mapper(base)
     mappers, selectable = primary_mapper.\
-                    _with_polymorphic_args(classes, selectable)
+                    _with_polymorphic_args(classes, selectable, innerjoin=innerjoin)
     if aliased:
         selectable = selectable.alias()
     return AliasedClass(base, 
@@ -478,11 +624,11 @@ def _orm_annotate(element, exclude=None):
 
 def _orm_deannotate(element):
     """Remove annotations that link a column to a particular mapping.
-    
+
     Note this doesn't affect "remote" and "foreign" annotations
     passed by the :func:`.orm.foreign` and :func:`.orm.remote`
     annotators.
-    
+
     """
 
     return sql_util._deep_deannotate(element, 
@@ -644,13 +790,24 @@ def with_parent(instance, prop):
                         value_is_parent=True)
 
 
+_extended_entity_info_tuple = util.namedtuple("extended_entity_info", [
+    "entity",
+    "mapper",
+    "selectable",
+    "is_aliased_class",
+    "with_polymorphic_mappers",
+    "with_polymorphic_discriminator"
+])
 def _extended_entity_info(entity, compile=True):
     if isinstance(entity, AliasedClass):
-        return entity._AliasedClass__mapper, \
-                entity._AliasedClass__alias, \
-                True, \
-                entity._AliasedClass__with_polymorphic_mappers, \
-                entity._AliasedClass__with_polymorphic_discriminator
+        return _extended_entity_info_tuple(
+            entity,
+            entity._AliasedClass__mapper, \
+                    entity._AliasedClass__alias, \
+                    True, \
+                    entity._AliasedClass__with_polymorphic_mappers, \
+                    entity._AliasedClass__with_polymorphic_discriminator
+        )
 
     if isinstance(entity, mapperlib.Mapper):
         mapper = entity
@@ -659,19 +816,22 @@ def _extended_entity_info(entity, compile=True):
         class_manager = attributes.manager_of_class(entity)
 
         if class_manager is None:
-            return None, entity, False, [], None
+            return _extended_entity_info_tuple(entity, None, entity, False, [], None)
 
         mapper = class_manager.mapper
     else:
-        return None, entity, False, [], None
+        return _extended_entity_info_tuple(entity, None, entity, False, [], None)
 
     if compile and mapperlib.module._new_mappers:
         mapperlib.configure_mappers()
-    return mapper, \
+    return _extended_entity_info_tuple(
+        entity, 
+        mapper, \
             mapper._with_polymorphic_selectable, \
             False, \
             mapper._with_polymorphic_mappers, \
             mapper.polymorphic_on
+        )
 
 def _entity_info(entity, compile=True):
     """Return mapping information given a class, mapper, or AliasedClass.
@@ -684,7 +844,7 @@ def _entity_info(entity, compile=True):
     unmapped selectables through.
 
     """
-    return _extended_entity_info(entity, compile)[0:3]
+    return _extended_entity_info(entity, compile)[1:4]
 
 def _entity_descriptor(entity, key):
     """Return a class attribute given an entity and string name.
@@ -738,7 +898,7 @@ def object_mapper(instance):
     Raises UnmappedInstanceError if no mapping is configured.
 
     This function is available via the inspection system as::
-    
+
         inspect(instance).mapper
 
     """
@@ -752,7 +912,7 @@ def object_state(instance):
     Raises UnmappedInstanceError if no mapping is configured.
 
     This function is available via the inspection system as::
-    
+
         inspect(instance)
 
     """
@@ -776,9 +936,9 @@ def class_mapper(class_, compile=True):
     object is passed.
 
     This function is available via the inspection system as::
-    
+
         inspect(some_mapped_class)
-    
+
     """
 
     try:
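
The PathRegistry added above is internal bookkeeping, but its behavior is
easy to demonstrate.  A minimal sketch, assuming a User/Address mapping such
as the one used in the pickling tests below; the Dummy object stands in for
the Query or QueryContext that normally carries the _attributes dictionary::

    from sqlalchemy.orm import class_mapper
    from sqlalchemy.orm.util import PathRegistry

    umapper = class_mapper(User)
    amapper = class_mapper(Address)

    # compose a path; indexing by a string key or an entity
    # extends the path by one element
    path = PathRegistry.coerce((umapper, 'addresses', amapper))
    sub = path['email_address']
    assert sub.path == (umapper, 'addresses', amapper, 'email_address')

    # paths compose keys into a registry's _attributes dictionary
    class Dummy(object):
        _attributes = {}

    reg = Dummy()
    path.set(reg, "eager_join_type", True)
    assert path.get(reg, "eager_join_type") is True
    assert path.contains(reg, "eager_join_type")
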
index f836d7eafa26f1e967cd136977ef1fe4f517bb9b..bc0497bea04ef5a10088c9d62b8d00ad9f1c1aef 100644 (file)
@@ -3500,9 +3500,14 @@ class _Exists(_UnaryExpression):
     def select(self, whereclause=None, **params):
         return select([self], whereclause, **params)
 
-    def correlate(self, fromclause):
+    def correlate(self, *fromclause):
         e = self._clone()
-        e.element = self.element.correlate(fromclause).self_group()
+        e.element = self.element.correlate(*fromclause).self_group()
+        return e
+
+    def correlate_except(self, *fromclause):
+        e = self._clone()
+        e.element = self.element.correlate_except(*fromclause).self_group()
         return e
 
     def select_from(self, clause):
@@ -4708,7 +4713,8 @@ class Select(_SelectBase):
     _hints = util.immutabledict()
     _distinct = False
     _from_cloned = None
-
+    _correlate = ()
+    _correlate_except = ()
     _memoized_property = _SelectBase._memoized_property
 
     def __init__(self, 
@@ -4750,7 +4756,6 @@ class Select(_SelectBase):
                                 for e in util.to_list(distinct)
                             ]
 
-        self._correlate = set()
         if from_obj is not None:
             self._from_obj = util.OrderedSet(
                                 _literal_as_text(f) 
@@ -4837,10 +4842,13 @@ class Select(_SelectBase):
             # using a list to maintain ordering
             froms = [f for f in froms if f not in toremove]
 
-        if len(froms) > 1 or self._correlate:
+        if len(froms) > 1 or self._correlate or self._correlate_except:
             if self._correlate:
                 froms = [f for f in froms if f not in _cloned_intersection(froms,
                         self._correlate)]
+            if self._correlate_except:
+                froms = [f for f in froms if f in _cloned_intersection(froms, 
+                        self._correlate_except)]
             if self._should_correlate and existing_froms:
                 froms = [f for f in froms if f not in _cloned_intersection(froms,
                         existing_froms)]
@@ -5198,16 +5206,24 @@ class Select(_SelectBase):
         """
         self._should_correlate = False
         if fromclauses and fromclauses[0] is None:
-            self._correlate = set()
+            self._correlate = ()
+        else:
+            self._correlate = set(self._correlate).union(fromclauses)
+
+    @_generative
+    def correlate_except(self, *fromclauses):
+        self._should_correlate = False
+        if fromclauses and fromclauses[0] is None:
+            self._correlate_except = ()
         else:
-            self._correlate = self._correlate.union(fromclauses)
+            self._correlate_except = set(self._correlate_except).union(fromclauses)
 
     def append_correlation(self, fromclause):
         """append the given correlation expression to this select()
         construct."""
 
         self._should_correlate = False
-        self._correlate = self._correlate.union([fromclause])
+        self._correlate = set(self._correlate).union([fromclause])
 
     def append_column(self, column):
         """append the given column expression to the columns clause of this
index 76c3c829d91b7e47da1cb6525e2df25e2b37ab47..3cfe55f9cebdc624deb9ff63903adb45d1efc17c 100644 (file)
@@ -7,7 +7,7 @@
 from compat import callable, cmp, reduce, defaultdict, py25_dict, \
     threading, py3k_warning, jython, pypy, win32, set_types, buffer, pickle, \
     update_wrapper, partial, md5_hex, decode_slice, dottedgetter,\
-    parse_qsl, any, contextmanager
+    parse_qsl, any, contextmanager, namedtuple
 
 from _collections import NamedTuple, ImmutableContainer, immutabledict, \
     Properties, OrderedProperties, ImmutableProperties, OrderedDict, \
index 99b92b1e340a9ed80507bafe7dcf92a488797c34..c5339d013408c0994c0ee6ce95a50070bc85b3cc 100644 (file)
@@ -111,6 +111,18 @@ else:
     cmp = cmp
     reduce = reduce
 
+try:
+    from collections import namedtuple
+except ImportError:
+    def namedtuple(typename, fieldnames):
+        def __new__(cls, *values):
+            tup = tuple.__new__(tuptype, values)
+            for i, fname in enumerate(fieldnames):
+                setattr(tup, fname, tup[i])
+            return tup
+        tuptype = type(typename, (tuple, ), {'__new__':__new__})
+        return tuptype
+
 try:
     from collections import defaultdict
 except ImportError:
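
The namedtuple fallback above only needs to cover the way the construct is
used internally: positional construction plus attribute access on a plain
tuple subclass.  Roughly, with either the stdlib version or the fallback in
scope (Point is a hypothetical example type)::

    Point = namedtuple("Point", ["x", "y"])
    p = Point(3, 4)
    assert isinstance(p, tuple)
    assert p[0] == 3 and p.x == 3
    assert p[1] == 4 and p.y == 4
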
index 26e7d193c480007796c6da657887443a592d72fd..04cf82a15d44df6271328639ec0387ee68a14891 100644 (file)
@@ -1,6 +1,7 @@
 from test.lib.testing import eq_
 from sqlalchemy.orm import mapper, relationship, create_session, \
-    clear_mappers, sessionmaker, class_mapper
+    clear_mappers, sessionmaker, class_mapper, aliased,\
+    Session, subqueryload
 from sqlalchemy.orm.mapper import _mapper_registry
 from sqlalchemy.orm.session import _sessions
 import operator
@@ -22,40 +23,44 @@ class A(fixtures.ComparableEntity):
     pass
 class B(fixtures.ComparableEntity):
     pass
+class ASub(A):
+    pass
 
-def profile_memory(func):
-    # run the test 50 times.  if length of gc.get_objects()
-    # keeps growing, assert false
+def profile_memory(times=50):
+    def decorate(func):
+        # run the test 50 times.  if length of gc.get_objects()
+        # keeps growing, assert false
 
-    def profile(*args):
-        gc_collect()
-        samples = [0 for x in range(0, 50)]
-        for x in range(0, 50):
-            func(*args)
+        def profile(*args):
             gc_collect()
-            samples[x] = len(gc.get_objects())
-
-        print "sample gc sizes:", samples
+            samples = [0 for x in range(0, times)]
+            for x in range(0, times):
+                func(*args)
+                gc_collect()
+                samples[x] = len(gc.get_objects())
 
-        assert len(_sessions) == 0
+            print "sample gc sizes:", samples
 
-        for x in samples[-4:]:
-            if x != samples[-5]:
-                flatline = False
-                break
-        else:
-            flatline = True
+            assert len(_sessions) == 0
 
-        # object count is bigger than when it started
-        if not flatline and samples[-1] > samples[0]:
-            for x in samples[1:-2]:
-                # see if a spike bigger than the endpoint exists
-                if x > samples[-1]:
+            for x in samples[-4:]:
+                if x != samples[-5]:
+                    flatline = False
                     break
             else:
-                assert False, repr(samples) + " " + repr(flatline)
+                flatline = True
 
-    return profile
+            # object count is bigger than when it started
+            if not flatline and samples[-1] > samples[0]:
+                for x in samples[1:-2]:
+                    # see if a spike bigger than the endpoint exists
+                    if x > samples[-1]:
+                        break
+                else:
+                    assert False, repr(samples) + " " + repr(flatline)
+
+        return profile
+    return decorate
 
 def assert_no_mappers():
     clear_mappers()
@@ -78,7 +83,7 @@ class MemUsageTest(EnsureZeroed):
             pass
 
         x = []
-        @profile_memory
+        @profile_memory()
         def go():
             x[-1:] = [Foo(), Foo(), Foo(), Foo(), Foo(), Foo()]
         go()
@@ -107,7 +112,7 @@ class MemUsageTest(EnsureZeroed):
 
         m3 = mapper(A, table1, non_primary=True)
 
-        @profile_memory
+        @profile_memory()
         def go():
             sess = create_session()
             a1 = A(col2="a1")
@@ -168,7 +173,7 @@ class MemUsageTest(EnsureZeroed):
 
         m3 = mapper(A, table1, non_primary=True)
 
-        @profile_memory
+        @profile_memory()
         def go():
             engine = engines.testing_engine(
                                 options={'logging_name':'FOO',
@@ -227,7 +232,7 @@ class MemUsageTest(EnsureZeroed):
             (postgresql.INTERVAL, ),
             (mysql.VARCHAR, ),
         ):
-            @profile_memory
+            @profile_memory()
             def go():
                 type_ = args[0](*args[1:])
                 bp = type_._cached_bind_processor(eng.dialect)
@@ -260,7 +265,7 @@ class MemUsageTest(EnsureZeroed):
         del session
         counter = [1]
 
-        @profile_memory
+        @profile_memory()
         def go():
             session = create_session()
             w1 = session.query(Wide).first()
@@ -282,11 +287,6 @@ class MemUsageTest(EnsureZeroed):
         finally:
             metadata.drop_all()
 
-    @testing.fails_if(lambda : testing.db.dialect.name == 'sqlite' \
-                      and testing.db.dialect.dbapi.version_info >= (2,
-                      5),
-                      'Newer pysqlites generate warnings here too and '
-                      'have similar issues.')
     def test_unicode_warnings(self):
         metadata = MetaData(testing.db)
         table1 = Table('mytable', metadata, Column('col1', Integer,
@@ -296,8 +296,11 @@ class MemUsageTest(EnsureZeroed):
         metadata.create_all()
         i = [1]
 
+        # the times here are cranked way up so that we can see
+        # pysqlite clearing out its internal buffer and allow
+        # the test to pass
         @testing.emits_warning()
-        @profile_memory
+        @profile_memory(times=220)
         def go():
 
             # execute with a non-unicode object. a warning is emitted,
@@ -325,7 +328,7 @@ class MemUsageTest(EnsureZeroed):
             Column('col2', String(30)),
             Column('col3', Integer, ForeignKey("mytable.col1")))
 
-        @profile_memory
+        @profile_memory()
         def go():
             m1 = mapper(A, table1, properties={
                 "bs":relationship(B, order_by=table2.c.col1)
@@ -368,6 +371,69 @@ class MemUsageTest(EnsureZeroed):
             metadata.drop_all()
         assert_no_mappers()
 
+    def test_alias_pathing(self):
+        metadata = MetaData(testing.db)
+
+        a = Table("a", metadata,
+            Column('id', Integer, primary_key=True,
+                                test_needs_autoincrement=True),
+            Column('bid', Integer, ForeignKey('b.id')),
+            Column('type', String(30))
+        )
+
+        asub = Table("asub", metadata,
+            Column('id', Integer, ForeignKey('a.id'),
+                                primary_key=True),
+            Column('data', String(30)))
+
+        b = Table("b", metadata,
+            Column('id', Integer, primary_key=True,
+                                test_needs_autoincrement=True),
+        )
+        mapper(A, a, polymorphic_identity='a',
+            polymorphic_on=a.c.type)
+        mapper(ASub, asub, inherits=A,polymorphic_identity='asub')
+        m1 = mapper(B, b, properties={
+            'as_':relationship(A)
+        })
+
+        metadata.create_all()
+        sess = Session()
+        a1 = ASub(data="a1")
+        a2 = ASub(data="a2")
+        a3 = ASub(data="a3")
+        b1 = B(as_=[a1, a2, a3])
+        sess.add(b1)
+        sess.commit()
+        del sess
+
+        # sqlite has a slow enough growth here
+        # that we have to run it more times to see the
+        # "dip" again
+        @profile_memory(times=120)
+        def go():
+            sess = Session()
+            sess.query(B).options(subqueryload(B.as_.of_type(ASub))).all()
+            sess.close()
+        try:
+            go()
+        finally:
+            metadata.drop_all()
+        clear_mappers()
+
+    def test_path_registry(self):
+        metadata = MetaData()
+        a = Table("a", metadata,
+            Column('id', Integer, primary_key=True),
+        )
+        m1 = mapper(A, a)
+        @profile_memory()
+        def go():
+            ma = aliased(A)
+            m1._sa_path_registry['foo'][ma]['bar']
+        go()
+        clear_mappers()
+
     def test_with_inheritance(self):
         metadata = MetaData(testing.db)
 
@@ -383,7 +449,7 @@ class MemUsageTest(EnsureZeroed):
             Column('col3', String(30)),
             )
 
-        @profile_memory
+        @profile_memory()
         def go():
             class A(fixtures.ComparableEntity):
                 pass
@@ -449,7 +515,7 @@ class MemUsageTest(EnsureZeroed):
             Column('t2', Integer, ForeignKey('mytable2.col1')),
             )
 
-        @profile_memory
+        @profile_memory()
         def go():
             class A(fixtures.ComparableEntity):
                 pass
@@ -505,7 +571,7 @@ class MemUsageTest(EnsureZeroed):
         t = Table('t', m, Column('x', Integer), Column('y', Integer))
         m.create_all(e)
         e.execute(t.insert(), {"x":1, "y":1})
-        @profile_memory
+        @profile_memory()
         def go():
             r = e.execute(t.alias().select())
             for row in r:
@@ -541,7 +607,7 @@ class MemUsageTest(EnsureZeroed):
         metadata.create_all()
         session = sessionmaker()
 
-        @profile_memory
+        @profile_memory()
         def go():
             s = table2.select()
             sess = session()
@@ -557,7 +623,7 @@ class MemUsageTest(EnsureZeroed):
     def test_type_compile(self):
         from sqlalchemy.dialects.sqlite.base import dialect as SQLiteDialect
         cast = sa.cast(column('x'), sa.Integer)
-        @profile_memory
+        @profile_memory()
         def go():
             dialect = SQLiteDialect()
             cast.compile(dialect=dialect)
@@ -565,21 +631,21 @@ class MemUsageTest(EnsureZeroed):
 
     @testing.requires.cextensions
     def test_DecimalResultProcessor_init(self):
-        @profile_memory
+        @profile_memory()
         def go():
             to_decimal_processor_factory({}, 10)
         go()
 
     @testing.requires.cextensions
     def test_DecimalResultProcessor_process(self):
-        @profile_memory
+        @profile_memory()
         def go():
             to_decimal_processor_factory(decimal.Decimal, 10)(1.2)
         go()
 
     @testing.requires.cextensions
     def test_UnicodeResultProcessor_init(self):
-        @profile_memory
+        @profile_memory()
         def go():
             to_unicode_processor_factory('utf8')
         go()
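
profile_memory is now a decorator factory, which is why every call site
above gains parentheses and why individual tests can raise the sample count.
The pattern, reduced to a self-contained sketch with illustrative names::

    def repeat(times=50):
        # calling the factory produces the actual decorator
        def decorate(func):
            def wrapper(*args):
                for _ in range(times):
                    func(*args)
            return wrapper
        return decorate

    @repeat()            # parens are required even with the default
    def ping():
        pass

    @repeat(times=220)   # crank up the count for noisy cases
    def ping_many():
        pass
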
index 0ee8f47e77084084cfe71aabaa4908abf950f0b9..6c502f58ddd3d232925a5d02f027cce2a5e981de 100644 (file)
@@ -154,6 +154,54 @@ class _PolymorphicFixtureBase(fixtures.MappedTest, AssertsCompiledSQL):
         cls.c1_employees = [e1, e2, b1, m1]
         cls.c2_employees = [e3]
 
+    def _company_with_emps_machines_fixture(self):
+        fixture = self._company_with_emps_fixture()
+        fixture[0].employees[0].machines = [
+            Machine(name="IBM ThinkPad"),
+            Machine(name="IPhone"),
+        ]
+        fixture[0].employees[1].machines = [
+            Machine(name="Commodore 64")
+        ]
+        return fixture
+
+    def _company_with_emps_fixture(self):
+        return [
+            Company(
+                name="MegaCorp, Inc.",
+                employees=[
+                    Engineer(
+                        name="dilbert",
+                        engineer_name="dilbert",
+                        primary_language="java",
+                        status="regular engineer"
+                    ),
+                    Engineer(
+                        name="wally",
+                        engineer_name="wally",
+                        primary_language="c++",
+                        status="regular engineer"),
+                    Boss(
+                        name="pointy haired boss",
+                        golf_swing="fore",
+                        manager_name="pointy",
+                        status="da boss"),
+                    Manager(
+                        name="dogbert",
+                        manager_name="dogbert",
+                        status="regular manager"),
+                ]),
+            Company(
+                name="Elbonia, Inc.",
+                employees=[
+                    Engineer(
+                        name="vlad",
+                        engineer_name="vlad",
+                        primary_language="cobol",
+                        status="elbonian engineer")
+                ])
+        ]
+
     def _emps_wo_relationships_fixture(self):
         return [
             Engineer(
index 792bd81099568d6270758935c864896f1b25aa6f..50593d39c74e81f018631a9d5cc01e88c0199ea0 100644 (file)
@@ -1,17 +1,14 @@
-from sqlalchemy import Integer, String, ForeignKey, func, desc, and_, or_
-from sqlalchemy.orm import interfaces, relationship, mapper, \
-    clear_mappers, create_session, joinedload, joinedload_all, \
-    subqueryload, subqueryload_all, polymorphic_union, aliased,\
+from sqlalchemy import func, desc
+from sqlalchemy.orm import interfaces, create_session, joinedload, joinedload_all, \
+    subqueryload, subqueryload_all, aliased,\
     class_mapper
 from sqlalchemy import exc as sa_exc
-from sqlalchemy.engine import default
 
-from test.lib import AssertsCompiledSQL, fixtures, testing
-from test.lib.schema import Table, Column
+from test.lib import testing
 from test.lib.testing import assert_raises, eq_
 
 from _poly_fixtures import Company, Person, Engineer, Manager, Boss, \
-    Machine, Paperwork, _PolymorphicFixtureBase, _Polymorphic,\
+    Machine, Paperwork, _Polymorphic,\
     _PolymorphicPolymorphic, _PolymorphicUnions, _PolymorphicJoins,\
     _PolymorphicAliasedJoins
 
@@ -323,31 +320,6 @@ class _PolymorphicTestBase(object):
                 .filter(any_).all(),
             [])
 
-    def test_polymorphic_any_four(self):
-        sess = create_session()
-        any_ = Company.employees.of_type(Engineer).any(
-            Engineer.primary_language == 'cobol')
-        eq_(sess.query(Company).filter(any_).one(), c2)
-
-    def test_polymorphic_any_five(self):
-        sess = create_session()
-        calias = aliased(Company)
-        any_ = calias.employees.of_type(Engineer).any(
-            Engineer.primary_language == 'cobol')
-        eq_(sess.query(calias).filter(any_).one(), c2)
-
-    def test_polymorphic_any_six(self):
-        sess = create_session()
-        any_ = Company.employees.of_type(Boss).any(
-            Boss.golf_swing == 'fore')
-        eq_(sess.query(Company).filter(any_).one(), c1)
-
-    def test_polymorphic_any_seven(self):
-        sess = create_session()
-        any_ = Company.employees.of_type(Boss).any(
-            Manager.manager_name == 'pointy')
-        eq_(sess.query(Company).filter(any_).one(), c1)
-
     def test_polymorphic_any_eight(self):
         sess = create_session()
         any_ = Engineer.machines.any(
@@ -360,11 +332,6 @@ class _PolymorphicTestBase(object):
             Paperwork.description == "review #2")
         eq_(sess.query(Person).filter(any_).all(), [m1])
 
-    def test_polymorphic_any_ten(self):
-        sess = create_session()
-        any_ = Company.employees.of_type(Engineer).any(
-            and_(Engineer.primary_language == 'cobol'))
-        eq_(sess.query(Company).filter(any_).one(), c2)
 
     def test_join_from_columns_or_subclass_one(self):
         sess = create_session()
@@ -529,17 +496,21 @@ class _PolymorphicTestBase(object):
                 .all(),
             expected)
 
+    # TODO: this fails due to the change
+    # in _configure_subclass_mapper.  however we might not
+    # need it anymore.
     def test_polymorphic_option(self):
         """
         Test that polymorphic loading sets state.load_path with its 
         actual mapper on a subclass, and not the superclass mapper.
-        """
 
+        This only works for non-aliased mappers.
+        """
         paths = []
         class MyOption(interfaces.MapperOption):
             propagate_to_loaders = True
             def process_query_conditionally(self, query):
-                paths.append(query._current_path)
+                paths.append(query._current_path.path)
 
         sess = create_session()
         names = ['dilbert', 'pointy haired boss']
@@ -556,6 +527,17 @@ class _PolymorphicTestBase(object):
             [(class_mapper(Engineer), 'machines'),
             (class_mapper(Boss), 'paperwork')])
 
+    def test_subclass_option_pathing(self):
+        from sqlalchemy.orm import defer
+        sess = create_session()
+        names = ['dilbert', 'pointy haired boss']
+        dilbert = sess.query(Person).\
+                options(defer(Engineer.machines, Machine.name)).\
+                filter(Person.name == 'dilbert').first()
+        m = dilbert.machines[0]
+        assert 'name' not in m.__dict__
+        eq_(m.name, 'IBM ThinkPad')
+
     def test_expire(self):
         """
         Test that individual column refresh doesn't get tripped up by 
@@ -639,69 +621,42 @@ class _PolymorphicTestBase(object):
             self._emps_wo_relationships_fixture())
 
 
-    def test_relationship_to_polymorphic(self):
-        expected = [
-            Company(
-                name="MegaCorp, Inc.",
-                employees=[
-                    Engineer(
-                        name="dilbert",
-                        engineer_name="dilbert",
-                        primary_language="java",
-                        status="regular engineer",
-                        machines=[
-                            Machine(name="IBM ThinkPad"),
-                            Machine(name="IPhone")]),
-                    Engineer(
-                        name="wally",
-                        engineer_name="wally",
-                        primary_language="c++",
-                        status="regular engineer"),
-                    Boss(
-                        name="pointy haired boss",
-                        golf_swing="fore",
-                        manager_name="pointy",
-                        status="da boss"),
-                    Manager(
-                        name="dogbert",
-                        manager_name="dogbert",
-                        status="regular manager"),
-                ]),
-            Company(
-                name="Elbonia, Inc.",
-                employees=[
-                    Engineer(
-                        name="vlad",
-                        engineer_name="vlad",
-                        primary_language="cobol",
-                        status="elbonian engineer")
-                ])
-        ]
-
+    def test_relationship_to_polymorphic_one(self):
+        expected = self._company_with_emps_machines_fixture()
         sess = create_session()
         def go():
             # test load Companies with lazy load to 'employees'
             eq_(sess.query(Company).all(), expected)
-        count = {'':9, 'Polymorphic':4}.get(self.select_type, 5)
+        count = {'':10, 'Polymorphic':5}.get(self.select_type, 6)
         self.assert_sql_count(testing.db, go, count)
 
+    def test_relationship_to_polymorphic_two(self):
+        expected = self._company_with_emps_machines_fixture()
         sess = create_session()
         def go():
-            # currently, it doesn't matter if we say Company.employees, 
-            # or Company.employees.of_type(Engineer).  joinedloader 
-            # doesn't pick up on the "of_type()" as of yet.
+            # with #2438, of_type() is recognized.  This
+            # overrides the with_polymorphic of the mapper
+            # and we get a consistent 3 queries now.
             eq_(sess.query(Company)
                     .options(joinedload_all(
                         Company.employees.of_type(Engineer),
                         Engineer.machines))
                     .all(),
                 expected)
-        # in the case of select_type='', the joinedload 
-        # doesn't take in this case; it joinedloads company->people, 
-        # then a load for each of 5 rows, then lazyload of "machines"
-        count = {'':7, 'Polymorphic':1}.get(self.select_type, 2)
+
+        # in the old case, we would get this
+        #count = {'':7, 'Polymorphic':1}.get(self.select_type, 2)
+
+        # query one is company->Person/Engineer->Machines
+        # query two is managers + boss for row #3
+        # query three is managers for row #4
+        count = 3
         self.assert_sql_count(testing.db, go, count)
 
+    def test_relationship_to_polymorphic_three(self):
+        expected = self._company_with_emps_machines_fixture()
+        sess = create_session()
+
         sess = create_session()
         def go():
             eq_(sess.query(Company)
@@ -710,12 +665,20 @@ class _PolymorphicTestBase(object):
                         Engineer.machines))
                     .all(),
                 expected)
-        count = {
-            '':8,
-            'Joins':4,
-            'Unions':4,
-            'Polymorphic':3,
-            'AliasedJoins':4}[self.select_type]
+
+        # the old case where subqueryload_all
+        # didn't work with of_type
+        #count = { '':8, 'Joins':4, 'Unions':4, 'Polymorphic':3,
+        #    'AliasedJoins':4}[self.select_type]
+
+        # query one is company->Person/Engineer->Machines
+        # query two is Person/Engineer subq
+        # query three is Machines subq 
+        # (however this test can't tell if the Q was a 
+        # lazyload or subqload ...)
+        # query four is managers + boss for row #3
+        # query five is managers for row #4
+        count = 5
         self.assert_sql_count(testing.db, go, count)
 
     def test_joinedload_on_subclass(self):
@@ -869,40 +832,6 @@ class _PolymorphicTestBase(object):
                 .filter(Machine.name.ilike("%ibm%")).all(),
             [e1, e3])
 
-    def test_join_to_subclass_eightteen(self):
-        sess = create_session()
-        # here's the new way
-        eq_(sess.query(Company)
-                .join(Company.employees.of_type(Engineer))
-                .filter(Engineer.primary_language == 'java').all(),
-            [c1])
-
-    def test_join_to_subclass_nineteen(self):
-        sess = create_session()
-        eq_(sess.query(Company)
-                .join(Company.employees.of_type(Engineer), 'machines')
-                .filter(Machine.name.ilike("%thinkpad%")).all(),
-            [c1])
-
-    def test_join_to_subclass_count(self):
-        sess = create_session()
-
-        eq_(sess.query(Company, Engineer)
-                .join(Company.employees.of_type(Engineer))
-                .filter(Engineer.primary_language == 'java').count(),
-            1)
-
-        # test [ticket:2093]
-        eq_(sess.query(Company.company_id, Engineer)
-                .join(Company.employees.of_type(Engineer))
-                .filter(Engineer.primary_language == 'java').count(),
-            1)
-
-        eq_(sess.query(Company)
-                .join(Company.employees.of_type(Engineer))
-                .filter(Engineer.primary_language == 'java').count(),
-            1)
-
     def test_join_through_polymorphic_nonaliased_one(self):
         sess = create_session()
         eq_(sess.query(Company)
@@ -1364,6 +1293,5 @@ class PolymorphicUnionsTest(_PolymorphicTestBase, _PolymorphicUnions):
 class PolymorphicAliasedJoinsTest(_PolymorphicTestBase, _PolymorphicAliasedJoins):
     pass
 
-
 class PolymorphicJoinsTest(_PolymorphicTestBase, _PolymorphicJoins):
     pass
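
For reference, the query forms exercised by the of_type() tests removed
above look like the following; a condensed sketch using this module's
Company/Engineer fixtures::

    sess = create_session()

    # of_type() inside any(); correlate_except() keeps the generated
    # EXISTS correlated to the enclosing query
    sess.query(Company).filter(
        Company.employees.of_type(Engineer).any(
            Engineer.primary_language == 'cobol')).all()

    # of_type() as a join target
    sess.query(Company).\
        join(Company.employees.of_type(Engineer)).\
        filter(Engineer.primary_language == 'java').all()
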
index 5885a4bda771a1d070cdd3ea7ef1498ac38adf79..951b3ec79a6bc84a7ee4003597b13ca52434ec0f 100644 (file)
@@ -1078,7 +1078,7 @@ class MergeTest(_fixtures.FixtureTest):
 
         for u in s1_users:
             ustate = attributes.instance_state(u)
-            eq_(ustate.load_path, (umapper, ))
+            eq_(ustate.load_path.path, (umapper, ))
             eq_(ustate.load_options, set([opt2]))
 
         # test 2.  present options are replaced by merge options
@@ -1086,7 +1086,7 @@ class MergeTest(_fixtures.FixtureTest):
         s1_users = sess.query(User).options(opt1).all()
         for u in s1_users:
             ustate = attributes.instance_state(u)
-            eq_(ustate.load_path, (umapper, ))
+            eq_(ustate.load_path.path, (umapper, ))
             eq_(ustate.load_options, set([opt1]))
 
         for u in s2_users:
@@ -1094,7 +1094,7 @@ class MergeTest(_fixtures.FixtureTest):
 
         for u in s1_users:
             ustate = attributes.instance_state(u)
-            eq_(ustate.load_path, (umapper, ))
+            eq_(ustate.load_path.path, (umapper, ))
             eq_(ustate.load_options, set([opt2]))
 
 
index f2d292832e3658a196a9c901bbf5090cbadb18e3..bb5bca984d3c032810ab7a14346009ac03558719 100644 (file)
@@ -108,6 +108,8 @@ class PickleTest(fixtures.MappedTest):
         eq_(str(u1), "User(name='ed')")
 
     def test_serialize_path(self):
+        from sqlalchemy.orm.util import PathRegistry
+
         users, addresses = (self.tables.users,
                                 self.tables.addresses)
 
@@ -117,24 +119,24 @@ class PickleTest(fixtures.MappedTest):
         amapper = mapper(Address, addresses)
 
         # this is a "relationship" path with mapper, key, mapper, key
-        p1 = (umapper, 'addresses', amapper, 'email_address')
+        p1 = PathRegistry.coerce((umapper, 'addresses', amapper, 'email_address'))
         eq_(
-            interfaces.deserialize_path(interfaces.serialize_path(p1)),
+            PathRegistry.deserialize(p1.serialize()),
             p1
         )
 
         # this is a "mapper" path with mapper, key, mapper, no key
         # at the end.
-        p2 = (umapper, 'addresses', amapper, )
+        p2 = PathRegistry.coerce((umapper, 'addresses', amapper, ))
         eq_(
-            interfaces.deserialize_path(interfaces.serialize_path(p2)),
+            PathRegistry.deserialize(p2.serialize()),
             p2
         )
 
         # test a blank path
-        p3 = ()
+        p3 = PathRegistry.root
         eq_(
-            interfaces.deserialize_path(interfaces.serialize_path(p3)),
+            PathRegistry.deserialize(p3.serialize()),
             p3
         )
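
The PathRegistry round trip exercised above can also be sketched standalone. The User/Address mapping below is a placeholder for the fixture tables, and the (mapper, string-key) path tuples follow the updated test exactly; the same kind of path object is what the merge tests above now read the raw tuple from via load_path.path:

    from sqlalchemy import Column, Integer, String, ForeignKey
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import relationship, class_mapper
    from sqlalchemy.orm.util import PathRegistry

    Base = declarative_base()

    class User(Base):
        __tablename__ = 'users'
        id = Column(Integer, primary_key=True)
        addresses = relationship("Address")

    class Address(Base):
        __tablename__ = 'addresses'
        id = Column(Integer, primary_key=True)
        user_id = Column(Integer, ForeignKey('users.id'))
        email_address = Column(String(50))

    umapper = class_mapper(User)
    amapper = class_mapper(Address)

    # a raw (mapper, key, mapper, key) tuple is coerced into a path object
    p1 = PathRegistry.coerce((umapper, 'addresses', amapper, 'email_address'))

    # serialize() yields a picklable form; deserialize() restores an equal path
    assert PathRegistry.deserialize(p1.serialize()) == p1

    # the empty path is available as PathRegistry.root
    root = PathRegistry.root
    assert PathRegistry.deserialize(root.serialize()) == root
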
 
index 9073f105698aecf040bab19cdf55549fe8e00a89..fcda72a8a3103f81cbe63223c34be6bdd368deec 100644 (file)
@@ -9,7 +9,8 @@ from sqlalchemy.sql import expression
 from sqlalchemy.engine import default
 from sqlalchemy.orm import attributes, mapper, relationship, backref, \
     configure_mappers, create_session, synonym, Session, class_mapper, \
-    aliased, column_property, joinedload_all, joinedload, Query
+    aliased, column_property, joinedload_all, joinedload, Query,\
+    util as orm_util
 from test.lib.assertsql import CompiledSQL
 from test.lib.testing import eq_
 from test.lib.schema import Table, Column
@@ -2119,7 +2120,7 @@ class ExecutionOptionsTest(QueryTest):
 
 
 class OptionsTest(QueryTest):
-    """Test the _get_paths() method of PropertyOption."""
+    """Test the _process_paths() method of PropertyOption."""
 
     def _option_fixture(self, *arg):
         from sqlalchemy.orm import interfaces
@@ -2136,11 +2137,15 @@ class OptionsTest(QueryTest):
             r.append(item)
         return tuple(r)
 
-    def _assert_path_result(self, opt, q, paths, mappers):
+    def _make_path_registry(self, path):
+        return orm_util.PathRegistry.coerce(self._make_path(path))
+
+    def _assert_path_result(self, opt, q, paths):
+        q._attributes = q._attributes.copy()
+        assert_paths = opt._process_paths(q, False)
         eq_(
-            opt._get_paths(q, False),
-            ([self._make_path(p) for p in paths], 
-            [class_mapper(c) for c in mappers])
+            [p.path for p in assert_paths],
+            [self._make_path(p) for p in paths]
         )
 
     def test_get_path_one_level_string(self):
@@ -2150,7 +2155,7 @@ class OptionsTest(QueryTest):
         q = sess.query(User)
 
         opt = self._option_fixture("addresses")
-        self._assert_path_result(opt, q, [(User, 'addresses')], [User])
+        self._assert_path_result(opt, q, [(User, 'addresses')])
 
     def test_get_path_one_level_attribute(self):
         User = self.classes.User
@@ -2159,7 +2164,7 @@ class OptionsTest(QueryTest):
         q = sess.query(User)
 
         opt = self._option_fixture(User.addresses)
-        self._assert_path_result(opt, q, [(User, 'addresses')], [User])
+        self._assert_path_result(opt, q, [(User, 'addresses')])
 
     def test_path_on_entity_but_doesnt_match_currentpath(self):
         User, Address = self.classes.User, self.classes.Address
@@ -2170,8 +2175,10 @@ class OptionsTest(QueryTest):
         sess = Session()
         q = sess.query(User)
         opt = self._option_fixture('email_address', 'id')
-        q = sess.query(Address)._with_current_path([class_mapper(User), 'addresses'])
-        self._assert_path_result(opt, q, [], [])
+        q = sess.query(Address)._with_current_path(
+                orm_util.PathRegistry.coerce([class_mapper(User), 'addresses'])
+            )
+        self._assert_path_result(opt, q, [])
 
     def test_get_path_one_level_with_unrelated(self):
         Order = self.classes.Order
@@ -2179,7 +2186,7 @@ class OptionsTest(QueryTest):
         sess = Session()
         q = sess.query(Order)
         opt = self._option_fixture("addresses")
-        self._assert_path_result(opt, q, [], [])
+        self._assert_path_result(opt, q, [])
 
     def test_path_multilevel_string(self):
         Item, User, Order = (self.classes.Item,
@@ -2194,8 +2201,7 @@ class OptionsTest(QueryTest):
             (User, 'orders'), 
             (User, 'orders', Order, 'items'),
             (User, 'orders', Order, 'items', Item, 'keywords')
-        ], 
-        [User, Order, Item])
+        ])
 
     def test_path_multilevel_attribute(self):
         Item, User, Order = (self.classes.Item,
@@ -2210,8 +2216,7 @@ class OptionsTest(QueryTest):
             (User, 'orders'), 
             (User, 'orders', Order, 'items'),
             (User, 'orders', Order, 'items', Item, 'keywords')
-        ], 
-        [User, Order, Item])
+        ])
 
     def test_with_current_matching_string(self):
         Item, User, Order = (self.classes.Item,
@@ -2220,13 +2225,13 @@ class OptionsTest(QueryTest):
 
         sess = Session()
         q = sess.query(Item)._with_current_path(
-                self._make_path([User, 'orders', Order, 'items'])
+                self._make_path_registry([User, 'orders', Order, 'items'])
             )
 
         opt = self._option_fixture("orders.items.keywords")
         self._assert_path_result(opt, q, [
             (Item, 'keywords')
-        ], [Item])
+        ])
 
     def test_with_current_matching_attribute(self):
         Item, User, Order = (self.classes.Item,
@@ -2235,13 +2240,13 @@ class OptionsTest(QueryTest):
 
         sess = Session()
         q = sess.query(Item)._with_current_path(
-                self._make_path([User, 'orders', Order, 'items'])
+                self._make_path_registry([User, 'orders', Order, 'items'])
             )
 
         opt = self._option_fixture(User.orders, Order.items, Item.keywords)
         self._assert_path_result(opt, q, [
             (Item, 'keywords')
-        ], [Item])
+        ])
 
     def test_with_current_nonmatching_string(self):
         Item, User, Order = (self.classes.Item,
@@ -2250,14 +2255,14 @@ class OptionsTest(QueryTest):
 
         sess = Session()
         q = sess.query(Item)._with_current_path(
-                self._make_path([User, 'orders', Order, 'items'])
+                self._make_path_registry([User, 'orders', Order, 'items'])
             )
 
         opt = self._option_fixture("keywords")
-        self._assert_path_result(opt, q, [], [])
+        self._assert_path_result(opt, q, [])
 
         opt = self._option_fixture("items.keywords")
-        self._assert_path_result(opt, q, [], [])
+        self._assert_path_result(opt, q, [])
 
     def test_with_current_nonmatching_attribute(self):
         Item, User, Order = (self.classes.Item,
@@ -2266,14 +2271,14 @@ class OptionsTest(QueryTest):
 
         sess = Session()
         q = sess.query(Item)._with_current_path(
-                self._make_path([User, 'orders', Order, 'items'])
+                self._make_path_registry([User, 'orders', Order, 'items'])
             )
 
         opt = self._option_fixture(Item.keywords)
-        self._assert_path_result(opt, q, [], [])
+        self._assert_path_result(opt, q, [])
 
         opt = self._option_fixture(Order.items, Item.keywords)
-        self._assert_path_result(opt, q, [], [])
+        self._assert_path_result(opt, q, [])
 
     def test_from_base_to_subclass_attr(self):
         Dingaling, Address = self.classes.Dingaling, self.classes.Address
@@ -2288,7 +2293,7 @@ class OptionsTest(QueryTest):
         q = sess.query(Address)
         opt = self._option_fixture(SubAddr.flub)
 
-        self._assert_path_result(opt, q, [(Address, 'flub')], [SubAddr])
+        self._assert_path_result(opt, q, [(Address, 'flub')])
 
     def test_from_subclass_to_subclass_attr(self):
         Dingaling, Address = self.classes.Dingaling, self.classes.Address
@@ -2303,7 +2308,7 @@ class OptionsTest(QueryTest):
         q = sess.query(SubAddr)
         opt = self._option_fixture(SubAddr.flub)
 
-        self._assert_path_result(opt, q, [(SubAddr, 'flub')], [SubAddr])
+        self._assert_path_result(opt, q, [(SubAddr, 'flub')])
 
     def test_from_base_to_base_attr_via_subclass(self):
         Dingaling, Address = self.classes.Dingaling, self.classes.Address
@@ -2318,7 +2323,7 @@ class OptionsTest(QueryTest):
         q = sess.query(Address)
         opt = self._option_fixture(SubAddr.user)
 
-        self._assert_path_result(opt, q, [(Address, 'user')], [Address])
+        self._assert_path_result(opt, q, [(Address, 'user')])
 
     def test_of_type(self):
         User, Address = self.classes.User, self.classes.Address
@@ -2334,7 +2339,7 @@ class OptionsTest(QueryTest):
         self._assert_path_result(opt, q, [
             (User, 'addresses'),
             (User, 'addresses', SubAddr, 'user')
-        ], [User, Address])
+        ])
 
     def test_of_type_plus_level(self):
         Dingaling, User, Address = (self.classes.Dingaling,
@@ -2354,7 +2359,7 @@ class OptionsTest(QueryTest):
         self._assert_path_result(opt, q, [
             (User, 'addresses'),
             (User, 'addresses', SubAddr, 'flub')
-        ], [User, SubAddr])
+        ])
 
     def test_aliased_single(self):
         User = self.classes.User
@@ -2363,7 +2368,7 @@ class OptionsTest(QueryTest):
         ualias = aliased(User)
         q = sess.query(ualias)
         opt = self._option_fixture(ualias.addresses)
-        self._assert_path_result(opt, q, [(ualias, 'addresses')], [User])
+        self._assert_path_result(opt, q, [(ualias, 'addresses')])
 
     def test_with_current_aliased_single(self):
         User, Address = self.classes.User, self.classes.Address
@@ -2371,10 +2376,10 @@ class OptionsTest(QueryTest):
         sess = Session()
         ualias = aliased(User)
         q = sess.query(ualias)._with_current_path(
-                        self._make_path([Address, 'user'])
+                        self._make_path_registry([Address, 'user'])
                 )
         opt = self._option_fixture(Address.user, ualias.addresses)
-        self._assert_path_result(opt, q, [(ualias, 'addresses')], [User])
+        self._assert_path_result(opt, q, [(ualias, 'addresses')])
 
     def test_with_current_aliased_single_nonmatching_option(self):
         User, Address = self.classes.User, self.classes.Address
@@ -2382,22 +2387,21 @@ class OptionsTest(QueryTest):
         sess = Session()
         ualias = aliased(User)
         q = sess.query(User)._with_current_path(
-                        self._make_path([Address, 'user'])
+                        self._make_path_registry([Address, 'user'])
                 )
         opt = self._option_fixture(Address.user, ualias.addresses)
-        self._assert_path_result(opt, q, [], [])
+        self._assert_path_result(opt, q, [])
 
-    @testing.fails_if(lambda: True, "Broken feature")
     def test_with_current_aliased_single_nonmatching_entity(self):
         User, Address = self.classes.User, self.classes.Address
 
         sess = Session()
         ualias = aliased(User)
         q = sess.query(ualias)._with_current_path(
-                        self._make_path([Address, 'user'])
+                        self._make_path_registry([Address, 'user'])
                 )
         opt = self._option_fixture(Address.user, User.addresses)
-        self._assert_path_result(opt, q, [], [])
+        self._assert_path_result(opt, q, [])
 
     def test_multi_entity_opt_on_second(self):
         Item = self.classes.Item
@@ -2405,7 +2409,7 @@ class OptionsTest(QueryTest):
         opt = self._option_fixture(Order.items)
         sess = Session()
         q = sess.query(Item, Order)
-        self._assert_path_result(opt, q, [(Order, "items")], [Order])
+        self._assert_path_result(opt, q, [(Order, "items")])
 
     def test_multi_entity_opt_on_string(self):
         Item = self.classes.Item
@@ -2413,7 +2417,7 @@ class OptionsTest(QueryTest):
         opt = self._option_fixture("items")
         sess = Session()
         q = sess.query(Item, Order)
-        self._assert_path_result(opt, q, [], [])
+        self._assert_path_result(opt, q, [])
 
     def test_multi_entity_no_mapped_entities(self):
         Item = self.classes.Item
@@ -2421,7 +2425,7 @@ class OptionsTest(QueryTest):
         opt = self._option_fixture("items")
         sess = Session()
         q = sess.query(Item.id, Order.id)
-        self._assert_path_result(opt, q, [], [])
+        self._assert_path_result(opt, q, [])
 
     def test_path_exhausted(self):
         User = self.classes.User
@@ -2430,9 +2434,9 @@ class OptionsTest(QueryTest):
         opt = self._option_fixture(User.orders)
         sess = Session()
         q = sess.query(Item)._with_current_path(
-                        self._make_path([User, 'orders', Order, 'items'])
+                        self._make_path_registry([User, 'orders', Order, 'items'])
                 )
-        self._assert_path_result(opt, q, [], [])
+        self._assert_path_result(opt, q, [])
 
 class OptionsNoPropTest(_fixtures.FixtureTest):
     """test the error messages emitted when using property
index 90df17609dc0f9d67f53eb994558cd3ae53825a7..53c50634eb0e9822d86022c9e3a418ce4340d704 100644 (file)
@@ -358,8 +358,6 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
         ("subqueryload", "subqueryload", "subqueryload", 4),
         ("subqueryload", "subqueryload", "joinedload", 3),
     ]
-#    _pathing_runs = [("subqueryload", "subqueryload", "joinedload", 3)]
-#    _pathing_runs = [("subqueryload", "subqueryload", "subqueryload", 4)]
 
     def test_options_pathing(self):
         self._do_options_test(self._pathing_runs)
index d24376cc9c9fd39081ab7646c083b20dddc108fe..23bad9c6969182fd01bac64588ed91d7f48d72d6 100644 (file)
@@ -159,7 +159,9 @@ print 'Total executemany calls: %d' \
     % counts_by_methname.get("<method 'executemany' of 'sqlite3.Cursor' "
                          "objects>", 0)
 
-os.system("runsnake %s" % filename)
+#stats.sort_stats('time', 'calls')
+#stats.print_stats()
+#os.system("runsnake %s" % filename)
 
 # SQLA Version: 0.7b1
 # Total calls 4956750
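
With the runsnake call disabled, the commented-out stats lines point at the stdlib pstats module as the way to inspect the profile dump. A minimal sketch, assuming a hypothetical dump file name (the real script derives its own filename and stats object earlier on):

    import pstats

    # hypothetical path; substitute the profile output written by orm2010.py
    filename = "orm2010.profile"

    stats = pstats.Stats(filename)
    stats.sort_stats('time', 'calls')
    stats.print_stats(25)   # top 25 entries by internal time, then call count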