]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- applying some refined versions of the ideas in the smarter_polymorphic
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 14 Jan 2008 02:45:30 +0000 (02:45 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 14 Jan 2008 02:45:30 +0000 (02:45 +0000)
branch
- slowly moving Query towards a central "aliasing" paradigm which merges
the aliasing of polymorphic mappers to aliasing against arbitrary select_from(),
to the eventual goal of polymorphic mappers which can also eagerload other
relations
- supports many more join() scenarios involving polymorphic mappers in
most configurations
- PropertyAliasedClauses doesn't need "path", EagerLoader doesn't need to
guess about "towrap"

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/util.py
test/orm/eager_relations.py
test/orm/inheritance/query.py

diff --git a/CHANGES b/CHANGES
index a460f53d4cac4f56b0f46896ed5f2a64858a6f09..3f3f178aea0e521eda29022a4b0fe8de5af379a9 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -19,6 +19,11 @@ CHANGES
       of being deferred until later.  This mimics the old 0.3
       behavior.
 
+    - general improvements to the behavior of join() in 
+      conjunction with polymorphic mappers, i.e. joining
+      from/to polymorphic mappers and properly applying 
+      aliases
+      
     - fixed bug in polymorphic inheritance which made it 
       difficult to set a working "order_by" on a polymorphic
       mapper
index 84a9bfeab3998ca233bf8a577a989cccfa9c14e2..c733c68ad2a7b131847305087d3ed8fb0dd919a4 100644 (file)
@@ -118,7 +118,8 @@ class Mapper(object):
         self._eager_loaders = util.Set()
         self._row_translators = {}
         self._dependency_processors = []
-
+        self._clause_adapter = None
+        
         # our 'polymorphic identity', a string name that when located in a result set row
         # indicates this Mapper should be used to construct the object instance for that row.
         self.polymorphic_identity = polymorphic_identity
@@ -738,6 +739,7 @@ class Mapper(object):
                     elif (isinstance(prop, list) and expression.is_column(prop[0])):
                         self.__surrogate_mapper.add_property(key, [_corresponding_column_or_error(self.select_table, c) for c in prop])
             
+            self.__surrogate_mapper._clause_adapter = adapter
 
     def _compile_class(self):
         """If this mapper is to be a primary mapper (i.e. the
index 19af7b47370a97a0ca49915e0ecc133179883da3..ca430378b87a77a0df7fea2e6543d2967401244f 100644 (file)
@@ -18,6 +18,7 @@ from sqlalchemy.orm import session as sessionlib
 from sqlalchemy.orm.util import CascadeOptions
 from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty
 from sqlalchemy.exceptions import ArgumentError
+import weakref
 
 __all__ = ('ColumnProperty', 'CompositeProperty', 'SynonymProperty',
            'PropertyLoader', 'BackRef')
@@ -207,7 +208,7 @@ class PropertyLoader(StrategizedProperty):
         self.passive_updates = passive_updates
         self.remote_side = util.to_set(remote_side)
         self.enable_typechecks = enable_typechecks
-        self._parent_join_cache = {}
+        self.__parent_join_cache = weakref.WeakKeyDictionary()
         self.comparator = PropertyLoader.Comparator(self)
         self.join_depth = join_depth
         self.strategy_class = strategy_class
@@ -681,51 +682,66 @@ class PropertyLoader(StrategizedProperty):
     def _is_self_referential(self):
         return self.parent.mapped_table is self.target or self.parent.select_table is self.target
 
-    def get_join(self, parent, primary=True, secondary=True, polymorphic_parent=True):
-        """return a join condition from the given parent mapper to this PropertyLoader's mapper.
-
-           The resulting ClauseElement object is cached and should not be modified directly.
-
-            parent
-              a mapper which has a relation() to this PropertyLoader.  A PropertyLoader can
-              have multiple "parents" when its actual parent mapper has inheriting mappers.
-
-            primary
-              include the primary join condition in the resulting join.
-
-            secondary
-              include the secondary join condition in the resulting join.  If both primary
-              and secondary are returned, they are joined via AND.
-
-            polymorphic_parent
-              if True, use the parent's 'select_table' instead of its 'mapped_table' to produce the join.
-        """
-
+    def primary_join_against(self, mapper, selectable=None):
+        return self.__cached_join_against(mapper, selectable, True, False)
+        
+    def secondary_join_against(self, mapper):
+        return self.__cached_join_against(mapper, None, False, True)
+        
+    def full_join_against(self, mapper, selectable=None):
+        return self.__cached_join_against(mapper, selectable, True, True)
+    
+    def __cached_join_against(self, mapper, selectable, primary, secondary):
+        if selectable is None:
+            selectable = mapper.local_table
+            
         try:
-            return self._parent_join_cache[(parent, primary, secondary, polymorphic_parent)]
+            rec = self.__parent_join_cache[selectable]
         except KeyError:
-            parent_equivalents = parent._equivalent_columns
-            secondaryjoin = self.polymorphic_secondaryjoin
-            if polymorphic_parent:
-                # adapt the "parent" side of our join condition to the "polymorphic" select of the parent
+            self.__parent_join_cache[selectable] = rec = {}
+
+        key = (mapper, primary, secondary)
+        if key in rec:
+            return rec[key]
+        
+        parent_equivalents = mapper._equivalent_columns
+        
+        if primary:
+            if selectable is not mapper.local_table:
                 if self.direction is sync.ONETOMANY:
-                    primaryjoin = ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
+                    primaryjoin = ClauseAdapter(selectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin)
                 elif self.direction is sync.MANYTOONE:
-                    primaryjoin = ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
+                    primaryjoin = ClauseAdapter(selectable, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin)
                 elif self.secondaryjoin:
-                    primaryjoin = ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
-
-            if secondaryjoin is not None:
-                if secondary and not primary:
-                    j = secondaryjoin
-                elif primary and secondary:
-                    j = primaryjoin & secondaryjoin
-                elif primary and not secondary:
-                    j = primaryjoin
+                    primaryjoin = ClauseAdapter(selectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin)
+            else:
+                primaryjoin = self.polymorphic_primaryjoin
+                
+            if secondary:
+                secondaryjoin = self.polymorphic_secondaryjoin
+                rec[key] = ret = primaryjoin & secondaryjoin
             else:
-                j = primaryjoin
-            self._parent_join_cache[(parent, primary, secondary, polymorphic_parent)] = j
-            return j
+                rec[key] = ret = primaryjoin
+            return ret
+        
+        elif secondary:
+            rec[key] = ret = self.polymorphic_secondaryjoin
+            return ret
+
+        else:
+            raise AssertionError("illegal condition")
+        
+    def get_join(self, parent, primary=True, secondary=True, polymorphic_parent=True):
+        """deprecated.  use primary_join_against(), secondary_join_against(), full_join_against()"""
+        
+        if primary and secondary:
+            return self.full_join_against(parent, parent.select_table)
+        elif primary:
+            return self.primary_join_against(parent, parent.select_table)
+        elif secondary:
+            return self.secondary_join_against(parent)
+        else:
+            raise AssertionError("illegal condition")
 
     def register_dependencies(self, uowcommit):
         if not self.viewonly:
index f651f04345644b144895d78ffd6507df3e2fd0d4..b3678f1aa071f244e1e159bf76252f0c16555fdc 100644 (file)
@@ -53,6 +53,7 @@ class Query(object):
         self._params = {}
         self._yield_per = None
         self._criterion = None
+        self._joinable_tables = None
         self._having = None
         self._column_aggregate = None
         self._joinpoint = self.mapper
@@ -64,12 +65,12 @@ class Query(object):
         self._autoflush = True
         self._eager_loaders = util.Set(chain(*[mp._eager_loaders for mp in [m for m in self.mapper.iterate_to_root()]]))
         self._attributes = {}
-        self.__joinable_tables = {}
         self._current_path = ()
-        self._primary_adapter=None
         self._only_load_props = None
         self._refresh_instance = None
-
+        
+        self._adapter = self.select_mapper._clause_adapter
+        
     def _no_criterion(self, meth):
         q = self._clone()
 
@@ -79,6 +80,7 @@ class Query(object):
                  "criterion is being ignored.") % meth)
 
         q._from_obj = self.table
+        q._adapter = self.select_mapper._clause_adapter
         q._alias_ids = {}
         q._joinpoint = self.mapper
         q._statement = q._aliases = q._criterion = None
@@ -357,7 +359,7 @@ class Query(object):
         q._params = q._params.copy()
         q._params.update(kwargs)
         return q
-
+    
     def filter(self, criterion):
         """apply the given filtering criterion to the query and return the newly resulting ``Query``
 
@@ -370,12 +372,9 @@ class Query(object):
         if criterion is not None and not isinstance(criterion, sql.ClauseElement):
             raise exceptions.ArgumentError("filter() argument must be of type sqlalchemy.sql.ClauseElement or string")
 
-
-        if self._aliases is not None:
-            criterion = self._aliases.adapt_clause(criterion)
-        elif self.table not in self._get_joinable_tables():
-            criterion = sql_util.ClauseAdapter(self._from_obj).traverse(criterion)
-
+        if self._adapter is not None:
+            criterion = self._adapter.traverse(criterion)
+            
         q = self._no_statement("filter")
         if q._criterion is not None:
             q._criterion = q._criterion & criterion
@@ -392,14 +391,16 @@ class Query(object):
         return self.filter(sql.and_(*clauses))
 
     def _get_joinable_tables(self):
-        if self._from_obj not in self.__joinable_tables:
+        if not self._joinable_tables or self._joinable_tables[0] is not self._from_obj:
             currenttables = [self._from_obj]
             def visit_join(join):
                 currenttables.append(join.left)
                 currenttables.append(join.right)
             visitors.traverse(self._from_obj, visit_join=visit_join, traverse_options={'column_collections':False, 'aliased_selectables':False})
-            self.__joinable_tables = {self._from_obj : currenttables}
-        return self.__joinable_tables[self._from_obj]
+            self._joinable_tables = (self._from_obj, currenttables)
+            return currenttables
+        else:
+            return self._joinable_tables[1]
 
     def _join_to(self, keys, outerjoin=False, start=None, create_aliases=True):
         if start is None:
@@ -408,7 +409,15 @@ class Query(object):
         clause = self._from_obj
 
         currenttables = self._get_joinable_tables()
-        adapt_criterion = self.table not in currenttables
+
+        # determine if generated joins need to be aliased on the left
+        # hand side.  
+        if self._adapter and not self._aliases:  # at the beginning of a join, look at leftmost adapter
+            adapt_against = self._adapter.selectable
+        elif start.select_table is not start.mapped_table: # in the middle of a join, look for a polymorphic mapper
+            adapt_against = start.select_table
+        else:
+            adapt_against = None
 
         mapper = start
         alias = self._aliases
@@ -421,35 +430,27 @@ class Query(object):
                 if prop.secondary:
                     if create_aliases:
                         alias = mapperutil.PropertyAliasedClauses(prop,
-                            prop.get_join(mapper, primary=True, secondary=False),
-                            prop.get_join(mapper, primary=False, secondary=True),
+                            prop.primary_join_against(mapper, adapt_against),
+                            prop.secondary_join_against(mapper),
                             alias
                         )
                         crit = alias.primaryjoin
-                        if adapt_criterion:
-                            crit = sql_util.ClauseAdapter(clause).traverse(crit)
                         clause = clause.join(alias.secondary, crit, isouter=outerjoin).join(alias.alias, alias.secondaryjoin, isouter=outerjoin)
                     else:
-                        crit = prop.get_join(mapper, primary=True, secondary=False)
-                        if adapt_criterion:
-                            crit = sql_util.ClauseAdapter(clause).traverse(crit)
+                        crit = prop.primary_join_against(mapper, adapt_against)
                         clause = clause.join(prop.secondary, crit, isouter=outerjoin)
-                        clause = clause.join(prop.select_table, prop.get_join(mapper, primary=False), isouter=outerjoin)
+                        clause = clause.join(prop.select_table, prop.secondary_join_against(mapper), isouter=outerjoin)
                 else:
                     if create_aliases:
                         alias = mapperutil.PropertyAliasedClauses(prop,
-                            prop.get_join(mapper, primary=True, secondary=False),
+                            prop.primary_join_against(mapper, adapt_against), 
                             None,
                             alias
                         )
                         crit = alias.primaryjoin
-                        if adapt_criterion:
-                            crit = sql_util.ClauseAdapter(clause).traverse(crit)
                         clause = clause.join(alias.alias, crit, isouter=outerjoin)
                     else:
-                        crit = prop.get_join(mapper)
-                        if adapt_criterion:
-                            crit = sql_util.ClauseAdapter(clause).traverse(crit)
+                        crit = prop.primary_join_against(mapper, adapt_against)
                         clause = clause.join(prop.select_table, crit, isouter=outerjoin)
             elif not create_aliases and prop.secondary is not None and prop.secondary not in currenttables:
                 # TODO: this check is not strong enough for different paths to the same endpoint which
@@ -458,6 +459,9 @@ class Query(object):
 
             mapper = prop.mapper
 
+            if mapper.select_table is not mapper.mapped_table:
+                adapt_against = mapper.select_table
+
         if create_aliases:
             return (clause, mapper, alias)
         else:
@@ -539,9 +543,9 @@ class Query(object):
 
         q = self._no_statement("order_by")
 
-        if self._aliases is not None:
+        if self._adapter:
             criterion = [expression._literal_as_text(o) for o in util.to_list(criterion) or []]
-            criterion = self._aliases.adapt_list(criterion)
+            criterion = self._adapter.copy_and_process(criterion)
 
         if q._order_by is False:
             q._order_by = util.to_list(criterion)
@@ -568,9 +572,8 @@ class Query(object):
         if criterion is not None and not isinstance(criterion, sql.ClauseElement):
             raise exceptions.ArgumentError("having() argument must be of type sqlalchemy.sql.ClauseElement or string")
 
-
-        if self._aliases is not None:
-            criterion = self._aliases.adapt_clause(criterion)
+        if self._adapter is not None:
+            criterion = self._adapter.traverse(criterion)
 
         q = self._no_statement("having")
         if q._having is not None:
@@ -605,6 +608,13 @@ class Query(object):
         q._from_obj = clause
         q._joinpoint = mapper
         q._aliases = aliases
+        
+        if aliases:
+            q._adapter = sql_util.ClauseAdapter(aliases.alias).copy_and_chain(q._adapter)
+        else:
+            select_mapper = mapper.get_select_mapper()
+            if select_mapper._clause_adapter:
+                q._adapter = select_mapper._clause_adapter.copy_and_chain(q._adapter)
 
         a = aliases
         while a is not None:
@@ -629,6 +639,8 @@ class Query(object):
         q = self._no_statement("reset_joinpoint")
         q._joinpoint = q.mapper
         q._aliases = None
+        if q.table not in q._get_joinable_tables():
+            q._adapter = sql_util.ClauseAdapter(q._from_obj, equivalents=q.mapper._equivalent_columns)
         return q
 
 
@@ -651,6 +663,9 @@ class Query(object):
             from_obj = from_obj.alias()
 
         new._from_obj = from_obj
+
+        if new.table not in new._get_joinable_tables():
+            new._adapter = sql_util.ClauseAdapter(new._from_obj, equivalents=new.mapper._equivalent_columns)
         return new
 
     def __getitem__(self, item):
@@ -787,9 +802,9 @@ class Query(object):
         mappers_or_columns = tuple(self._entities) + mappers_or_columns
         tuples = bool(mappers_or_columns)
 
-        if self._primary_adapter:
+        if context.row_adapter:
             def main(context, row):
-                return self.select_mapper._instance(context, self._primary_adapter(row), None,
+                return self.select_mapper._instance(context, context.row_adapter(row), None,
                     extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance
                 )
         else:
@@ -957,17 +972,18 @@ class Query(object):
 
         from_obj = self._from_obj
 
-        # indicates if the "from" clause of the query does not include
-        # the normally mapped table, i.e. the user issued select_from(somestatement)
-        # or similar.  all clauses which derive from the mapped table will need to
-        # be adapted to be relative to the user-supplied selectable.
-        adapt_criterion = self.table not in self._get_joinable_tables()
-
-        # adapt for poylmorphic mapper
-        # TODO: generalize the polymorphic mapper adaption to that of the select_from() adaption
-        if not adapt_criterion and whereclause is not None and (self.mapper is not self.select_mapper):
-            whereclause = sql_util.ClauseAdapter(from_obj, equivalents=self.select_mapper._equivalent_columns).traverse(whereclause)
+        # if the query's ClauseAdapter is present, and its
+        # specifically adapting against a modified "select_from"
+        # argument, apply adaptation to the
+        # individually selected columns as well as "eager" clauses added;
+        # otherwise its currently not needed
+        if self._adapter and self.table not in self._get_joinable_tables():
+            adapter = self._adapter
+        else:
+            adapter = None
 
+        adapter = self._adapter
+        
         # TODO: mappers added via add_entity(), adapt their queries also,
         # if those mappers are polymorphic
 
@@ -1029,7 +1045,9 @@ class Query(object):
                 for o in order_by:
                     cf.update(sql_util.find_columns(o))
 
-            if adapt_criterion:
+            if adapter:
+                # TODO: make usage of the ClauseAdapter here to create the list
+                # of primary columns
                 context.primary_columns = [from_obj.corresponding_column(c) or c for c in context.primary_columns]
                 cf = [from_obj.corresponding_column(c) or c for c in cf]
 
@@ -1037,7 +1055,7 @@ class Query(object):
 
             s3 = s2.alias()
 
-            self._primary_adapter = mapperutil.create_row_adapter(s3, self.table)
+            context.row_adapter = mapperutil.create_row_adapter(s3, self.table)
 
             statement = sql.select([s3] + context.secondary_columns, for_update=for_update, use_labels=True)
 
@@ -1050,17 +1068,16 @@ class Query(object):
 
             statement.append_order_by(*context.eager_order_by)
         else:
-            if adapt_criterion:
+            if adapter:
+                # TODO: make usage of the ClauseAdapter here to create row adapter, list
+                # of primary columns
                 context.primary_columns = [from_obj.corresponding_column(c) or c for c in context.primary_columns]
-                self._primary_adapter = mapperutil.create_row_adapter(from_obj, self.table)
+                context.row_adapter = mapperutil.create_row_adapter(from_obj, self.table)
 
-            if adapt_criterion or self._distinct:
+            if self._distinct:
                 if order_by:
                     order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []]
 
-                    if adapt_criterion:
-                        order_by = sql_util.ClauseAdapter(from_obj).copy_and_process(order_by)
-
                 if self._distinct and order_by:
                     cf = util.Set()
                     for o in order_by:
@@ -1071,13 +1088,13 @@ class Query(object):
             statement = sql.select(context.primary_columns + context.secondary_columns, whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, order_by=util.to_list(order_by), **self._select_args())
             
             if context.eager_joins:
-                if adapt_criterion:
-                    context.eager_joins = sql_util.ClauseAdapter(from_obj).traverse(context.eager_joins)
+                if adapter:
+                    context.eager_joins = adapter.traverse(context.eager_joins)
                 statement.append_from(context.eager_joins, _copy_collection=False)
 
             if context.eager_order_by:
-                if adapt_criterion:
-                    context.eager_order_by = sql_util.ClauseAdapter(from_obj).copy_and_process(context.eager_order_by)
+                if adapter:
+                    context.eager_order_by = adapter.copy_and_process(context.eager_order_by)
                 statement.append_order_by(*context.eager_order_by)
 
         context.statement = statement
@@ -1103,6 +1120,7 @@ class Query(object):
                 return self._alias_ids[alias_id]
             except KeyError:
                 raise exceptions.InvalidRequestError("Query has no alias identified by '%s'" % alias_id)
+                
         if isinstance(m, type):
             m = mapper.class_mapper(m)
         if isinstance(m, mapper.Mapper):
@@ -1369,6 +1387,7 @@ class QueryContext(object):
         self.session = query.session
         self.extension = query._extension
         self.statement = None
+        self.row_adapter = None
         self.populate_existing = query._populate_existing
         self.version_check = query._version_check
         self.only_load_props = query._only_load_props
index a715d924a1faa74bc718bda574bc75985db1f2a0..908c43feb15a1f619f47dbab1b917347c98fbde8 100644 (file)
@@ -519,10 +519,7 @@ class EagerLoader(AbstractRelationLoader):
         if context.eager_joins:
             towrap = context.eager_joins
         else:
-            if isinstance(context.from_clause, sql.Join):
-                towrap = context.from_clause
-            else:
-                towrap = localparent.mapped_table
+            towrap = context.from_clause
         
         # create AliasedClauses object to build up the eager query.  this is cached after 1st creation.    
         try:
index 7473609d74d88787f889d971d92e413c3a46b88f..4f2ab5444af1e6e9c0a88e86326f1ee86eb3d009 100644 (file)
@@ -236,10 +236,6 @@ class PropertyAliasedClauses(AliasedClauses):
         super(PropertyAliasedClauses, self).__init__(prop.select_table)
             
         self.parentclauses = parentclauses
-        if parentclauses is not None:
-            self.path = build_path(prop.parent, prop.key, parentclauses.path)
-        else:
-            self.path = build_path(prop.parent, prop.key)
 
         self.prop = prop
         
@@ -261,6 +257,7 @@ class PropertyAliasedClauses(AliasedClauses):
                 aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, exclude=prop.remote_side))
             else:
                 aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side)
+
             self.primaryjoin = aliasizer.traverse(primaryjoin, clone=True)
             self.secondary = None
             self.secondaryjoin = None
@@ -273,9 +270,6 @@ class PropertyAliasedClauses(AliasedClauses):
     mapper = property(lambda self:self.prop.mapper)
     table = property(lambda self:self.prop.select_table)
     
-    def __str__(self):
-        return "->".join([str(s) for s in self.path])
-
 
 def instance_str(instance):
     """Return a string describing an instance."""
index b45c0425c8dc1c793cfdecaafddd3a286172537d..c2ac26557ee06c96dcdcc124323f63a07194b9c8 100644 (file)
@@ -186,6 +186,25 @@ class ClauseAdapter(AbstractClauseProcessor):
         self.exclude = exclude
         self.equivalents = equivalents
 
+    def copy_and_chain(self, adapter):
+        """create a copy of this adapter and chain to the given adapter.
+        
+        currently this adapter must be unchained to start, raises
+        an exception if it's already chained.  
+        
+        Does not modify the given adapter.
+        """
+        
+        if adapter is None:
+            return self
+            
+        if hasattr(self, '_next_acp') or hasattr(self, '_next'):
+            raise NotImplementedError("Can't chain_to on an already chained ClauseAdapter (yet)")
+            
+        ca = ClauseAdapter(self.selectable, self.include, self.exclude, self.equivalents)
+        ca._next_acp = adapter
+        return ca
+        
     def convert_element(self, col):
         if isinstance(col, expression.FromClause):
             if self.selectable.is_derived_from(col):
index e42ef5cb8172e057349160978f552703dab38fa7..f35fbcbfc28e5481431a0bf4e0f9f573ee59f6e2 100644 (file)
@@ -195,6 +195,10 @@ class EagerTest(FixtureTest):
             assert fixtures.item_keyword_result[0:2] == q.join('keywords').filter(keywords.c.name == 'red').all()
         self.assert_sql_count(testing.db, go, 1)
 
+        def go():
+            assert fixtures.item_keyword_result[0:2] == q.join('keywords', aliased=True).filter(keywords.c.name == 'red').all()
+        self.assert_sql_count(testing.db, go, 1)
+
 
     def test_eager_option(self):
         mapper(Keyword, keywords)
index 698df33fa70d7ccc0330b7f7d534572242eb2b35..2a15ae1b0194a3f7af8ee19793aa6f9e9aa0077a 100644 (file)
@@ -93,7 +93,7 @@ class PolymorphicQueryTest(ORMTest):
         mapper(Paperwork, paperwork)
 
     def insert_data(self):
-        global all_employees, c1_employees, c2_employees, e1, e2, b1, m1, e3
+        global all_employees, c1_employees, c2_employees, e1, e2, b1, m1, e3, c1, c2
 
         c1 = Company(name="MegaCorp, Inc.")
         c2 = Company(name="Elbonia, Inc.")
@@ -114,7 +114,9 @@ class PolymorphicQueryTest(ORMTest):
         ])
         c1.employees = [e1, e2, b1, m1]
 
-        e3 = Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer")
+        e3 = Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer", paperwork=[
+            Paperwork(description='elbonian missive #3')
+        ])
         c2.employees = [e3]
         sess = create_session()
         sess.save(c1)
@@ -127,9 +129,6 @@ class PolymorphicQueryTest(ORMTest):
         c2_employees = [e3]
 
     def test_filter_on_subclass(self):
-        print Manager.person_id == Engineer.person_id
-        print Manager.c.person_id == Engineer.c.person_id
-        
         sess = create_session()
         self.assertEquals(sess.query(Engineer).all()[0], Engineer(name="dilbert"))
 
@@ -142,12 +141,74 @@ class PolymorphicQueryTest(ORMTest):
         self.assertEquals(sess.query(Manager).filter(Manager.person_id==b1.person_id).one(), Boss(name="pointy haired boss"))
         
         self.assertEquals(sess.query(Boss).filter(Boss.person_id==b1.person_id).one(), Boss(name="pointy haired boss"))
+
+    def test_join_from_polymorphic(self):
+        sess = create_session()
         
-    def test_load_all(self):
+        for aliased in (True, False):
+            self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1])
+
+            self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1])
+
+            self.assertEquals(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1])
+
+            self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
+    
+    def test_join_to_polymorphic(self):
+        sess = create_session()
+        self.assertEquals(sess.query(Company).join('employees').filter(Person.name=='vlad').one(), c2)
+
+        self.assertEquals(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2)
+    
+    def test_join_through_polymorphic(self):
+        sess = create_session()
+
+        for aliased in (True, False):
+            self.assertEquals(
+                sess.query(Company).\
+                    join(['employees', 'paperwork'], aliased=aliased).filter(Paperwork.description.like('%#2%')).all(),
+                [c1]
+            )
+
+            self.assertEquals(
+                sess.query(Company).\
+                    join(['employees', 'paperwork'], aliased=aliased).filter(Paperwork.description.like('%#%')).all(),
+                [c1, c2]
+            )
+
+            self.assertEquals(
+                sess.query(Company).\
+                    join(['employees', 'paperwork'], aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).filter(Paperwork.description.like('%#2%')).all(),
+                [c1]
+            )
+        
+            self.assertEquals(
+                sess.query(Company).\
+                    join(['employees', 'paperwork'], aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).filter(Paperwork.description.like('%#%')).all(),
+                [c1, c2]
+            )
+
+            self.assertEquals(
+                sess.query(Company).join('employees', aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).\
+                    join('paperwork', from_joinpoint=True, aliased=aliased).filter(Paperwork.description.like('%#2%')).all(),
+                [c1]
+            )
+
+            self.assertEquals(
+                sess.query(Company).join('employees', aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).\
+                    join('paperwork', from_joinpoint=True, aliased=aliased).filter(Paperwork.description.like('%#%')).all(),
+                [c1, c2]
+            )
+        
+    def test_filter_on_baseclass(self):
         sess = create_session()
 
         self.assertEquals(sess.query(Person).all(), all_employees)
 
+        self.assertEquals(sess.query(Person).first(), all_employees[0])
+        
+        self.assertEquals(sess.query(Person).filter(Person.person_id==e2.person_id).one(), e2)
+
 
 if __name__ == "__main__":
     testenv.main()