From 9e1a35ef3daaee6590830ae5f2c0c9045d682b9d Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 14 Jan 2008 02:45:30 +0000 Subject: [PATCH] - applying some refined versions of the ideas in the smarter_polymorphic branch - slowly moving Query towards a central "aliasing" paradigm which merges the aliasing of polymorphic mappers to aliasing against arbitrary select_from(), to the eventual goal of polymorphic mappers which can also eagerload other relations - supports many more join() scenarios involving polymorphic mappers in most configurations - PropertyAliasedClauses doesn't need "path", EagerLoader doesn't need to guess about "towrap" --- CHANGES | 5 ++ lib/sqlalchemy/orm/mapper.py | 4 +- lib/sqlalchemy/orm/properties.py | 96 ++++++++++++---------- lib/sqlalchemy/orm/query.py | 133 ++++++++++++++++++------------- lib/sqlalchemy/orm/strategies.py | 5 +- lib/sqlalchemy/orm/util.py | 8 +- lib/sqlalchemy/sql/util.py | 19 +++++ test/orm/eager_relations.py | 4 + test/orm/inheritance/query.py | 73 +++++++++++++++-- 9 files changed, 232 insertions(+), 115 deletions(-) diff --git a/CHANGES b/CHANGES index a460f53d4c..3f3f178aea 100644 --- a/CHANGES +++ b/CHANGES @@ -19,6 +19,11 @@ CHANGES of being deferred until later. This mimics the old 0.3 behavior. + - general improvements to the behavior of join() in + conjunction with polymorphic mappers, i.e. joining + from/to polymorphic mappers and properly applying + aliases + - fixed bug in polymorphic inheritance which made it difficult to set a working "order_by" on a polymorphic mapper diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 84a9bfeab3..c733c68ad2 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -118,7 +118,8 @@ class Mapper(object): self._eager_loaders = util.Set() self._row_translators = {} self._dependency_processors = [] - + self._clause_adapter = None + # our 'polymorphic identity', a string name that when located in a result set row # indicates this Mapper should be used to construct the object instance for that row. self.polymorphic_identity = polymorphic_identity @@ -738,6 +739,7 @@ class Mapper(object): elif (isinstance(prop, list) and expression.is_column(prop[0])): self.__surrogate_mapper.add_property(key, [_corresponding_column_or_error(self.select_table, c) for c in prop]) + self.__surrogate_mapper._clause_adapter = adapter def _compile_class(self): """If this mapper is to be a primary mapper (i.e. the diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 19af7b4737..ca430378b8 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -18,6 +18,7 @@ from sqlalchemy.orm import session as sessionlib from sqlalchemy.orm.util import CascadeOptions from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty from sqlalchemy.exceptions import ArgumentError +import weakref __all__ = ('ColumnProperty', 'CompositeProperty', 'SynonymProperty', 'PropertyLoader', 'BackRef') @@ -207,7 +208,7 @@ class PropertyLoader(StrategizedProperty): self.passive_updates = passive_updates self.remote_side = util.to_set(remote_side) self.enable_typechecks = enable_typechecks - self._parent_join_cache = {} + self.__parent_join_cache = weakref.WeakKeyDictionary() self.comparator = PropertyLoader.Comparator(self) self.join_depth = join_depth self.strategy_class = strategy_class @@ -681,51 +682,66 @@ class PropertyLoader(StrategizedProperty): def _is_self_referential(self): return self.parent.mapped_table is self.target or self.parent.select_table is self.target - def get_join(self, parent, primary=True, secondary=True, polymorphic_parent=True): - """return a join condition from the given parent mapper to this PropertyLoader's mapper. - - The resulting ClauseElement object is cached and should not be modified directly. - - parent - a mapper which has a relation() to this PropertyLoader. A PropertyLoader can - have multiple "parents" when its actual parent mapper has inheriting mappers. - - primary - include the primary join condition in the resulting join. - - secondary - include the secondary join condition in the resulting join. If both primary - and secondary are returned, they are joined via AND. - - polymorphic_parent - if True, use the parent's 'select_table' instead of its 'mapped_table' to produce the join. - """ - + def primary_join_against(self, mapper, selectable=None): + return self.__cached_join_against(mapper, selectable, True, False) + + def secondary_join_against(self, mapper): + return self.__cached_join_against(mapper, None, False, True) + + def full_join_against(self, mapper, selectable=None): + return self.__cached_join_against(mapper, selectable, True, True) + + def __cached_join_against(self, mapper, selectable, primary, secondary): + if selectable is None: + selectable = mapper.local_table + try: - return self._parent_join_cache[(parent, primary, secondary, polymorphic_parent)] + rec = self.__parent_join_cache[selectable] except KeyError: - parent_equivalents = parent._equivalent_columns - secondaryjoin = self.polymorphic_secondaryjoin - if polymorphic_parent: - # adapt the "parent" side of our join condition to the "polymorphic" select of the parent + self.__parent_join_cache[selectable] = rec = {} + + key = (mapper, primary, secondary) + if key in rec: + return rec[key] + + parent_equivalents = mapper._equivalent_columns + + if primary: + if selectable is not mapper.local_table: if self.direction is sync.ONETOMANY: - primaryjoin = ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) + primaryjoin = ClauseAdapter(selectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin) elif self.direction is sync.MANYTOONE: - primaryjoin = ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) + primaryjoin = ClauseAdapter(selectable, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin) elif self.secondaryjoin: - primaryjoin = ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) - - if secondaryjoin is not None: - if secondary and not primary: - j = secondaryjoin - elif primary and secondary: - j = primaryjoin & secondaryjoin - elif primary and not secondary: - j = primaryjoin + primaryjoin = ClauseAdapter(selectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin) + else: + primaryjoin = self.polymorphic_primaryjoin + + if secondary: + secondaryjoin = self.polymorphic_secondaryjoin + rec[key] = ret = primaryjoin & secondaryjoin else: - j = primaryjoin - self._parent_join_cache[(parent, primary, secondary, polymorphic_parent)] = j - return j + rec[key] = ret = primaryjoin + return ret + + elif secondary: + rec[key] = ret = self.polymorphic_secondaryjoin + return ret + + else: + raise AssertionError("illegal condition") + + def get_join(self, parent, primary=True, secondary=True, polymorphic_parent=True): + """deprecated. use primary_join_against(), secondary_join_against(), full_join_against()""" + + if primary and secondary: + return self.full_join_against(parent, parent.select_table) + elif primary: + return self.primary_join_against(parent, parent.select_table) + elif secondary: + return self.secondary_join_against(parent) + else: + raise AssertionError("illegal condition") def register_dependencies(self, uowcommit): if not self.viewonly: diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index f651f04345..b3678f1aa0 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -53,6 +53,7 @@ class Query(object): self._params = {} self._yield_per = None self._criterion = None + self._joinable_tables = None self._having = None self._column_aggregate = None self._joinpoint = self.mapper @@ -64,12 +65,12 @@ class Query(object): self._autoflush = True self._eager_loaders = util.Set(chain(*[mp._eager_loaders for mp in [m for m in self.mapper.iterate_to_root()]])) self._attributes = {} - self.__joinable_tables = {} self._current_path = () - self._primary_adapter=None self._only_load_props = None self._refresh_instance = None - + + self._adapter = self.select_mapper._clause_adapter + def _no_criterion(self, meth): q = self._clone() @@ -79,6 +80,7 @@ class Query(object): "criterion is being ignored.") % meth) q._from_obj = self.table + q._adapter = self.select_mapper._clause_adapter q._alias_ids = {} q._joinpoint = self.mapper q._statement = q._aliases = q._criterion = None @@ -357,7 +359,7 @@ class Query(object): q._params = q._params.copy() q._params.update(kwargs) return q - + def filter(self, criterion): """apply the given filtering criterion to the query and return the newly resulting ``Query`` @@ -370,12 +372,9 @@ class Query(object): if criterion is not None and not isinstance(criterion, sql.ClauseElement): raise exceptions.ArgumentError("filter() argument must be of type sqlalchemy.sql.ClauseElement or string") - - if self._aliases is not None: - criterion = self._aliases.adapt_clause(criterion) - elif self.table not in self._get_joinable_tables(): - criterion = sql_util.ClauseAdapter(self._from_obj).traverse(criterion) - + if self._adapter is not None: + criterion = self._adapter.traverse(criterion) + q = self._no_statement("filter") if q._criterion is not None: q._criterion = q._criterion & criterion @@ -392,14 +391,16 @@ class Query(object): return self.filter(sql.and_(*clauses)) def _get_joinable_tables(self): - if self._from_obj not in self.__joinable_tables: + if not self._joinable_tables or self._joinable_tables[0] is not self._from_obj: currenttables = [self._from_obj] def visit_join(join): currenttables.append(join.left) currenttables.append(join.right) visitors.traverse(self._from_obj, visit_join=visit_join, traverse_options={'column_collections':False, 'aliased_selectables':False}) - self.__joinable_tables = {self._from_obj : currenttables} - return self.__joinable_tables[self._from_obj] + self._joinable_tables = (self._from_obj, currenttables) + return currenttables + else: + return self._joinable_tables[1] def _join_to(self, keys, outerjoin=False, start=None, create_aliases=True): if start is None: @@ -408,7 +409,15 @@ class Query(object): clause = self._from_obj currenttables = self._get_joinable_tables() - adapt_criterion = self.table not in currenttables + + # determine if generated joins need to be aliased on the left + # hand side. + if self._adapter and not self._aliases: # at the beginning of a join, look at leftmost adapter + adapt_against = self._adapter.selectable + elif start.select_table is not start.mapped_table: # in the middle of a join, look for a polymorphic mapper + adapt_against = start.select_table + else: + adapt_against = None mapper = start alias = self._aliases @@ -421,35 +430,27 @@ class Query(object): if prop.secondary: if create_aliases: alias = mapperutil.PropertyAliasedClauses(prop, - prop.get_join(mapper, primary=True, secondary=False), - prop.get_join(mapper, primary=False, secondary=True), + prop.primary_join_against(mapper, adapt_against), + prop.secondary_join_against(mapper), alias ) crit = alias.primaryjoin - if adapt_criterion: - crit = sql_util.ClauseAdapter(clause).traverse(crit) clause = clause.join(alias.secondary, crit, isouter=outerjoin).join(alias.alias, alias.secondaryjoin, isouter=outerjoin) else: - crit = prop.get_join(mapper, primary=True, secondary=False) - if adapt_criterion: - crit = sql_util.ClauseAdapter(clause).traverse(crit) + crit = prop.primary_join_against(mapper, adapt_against) clause = clause.join(prop.secondary, crit, isouter=outerjoin) - clause = clause.join(prop.select_table, prop.get_join(mapper, primary=False), isouter=outerjoin) + clause = clause.join(prop.select_table, prop.secondary_join_against(mapper), isouter=outerjoin) else: if create_aliases: alias = mapperutil.PropertyAliasedClauses(prop, - prop.get_join(mapper, primary=True, secondary=False), + prop.primary_join_against(mapper, adapt_against), None, alias ) crit = alias.primaryjoin - if adapt_criterion: - crit = sql_util.ClauseAdapter(clause).traverse(crit) clause = clause.join(alias.alias, crit, isouter=outerjoin) else: - crit = prop.get_join(mapper) - if adapt_criterion: - crit = sql_util.ClauseAdapter(clause).traverse(crit) + crit = prop.primary_join_against(mapper, adapt_against) clause = clause.join(prop.select_table, crit, isouter=outerjoin) elif not create_aliases and prop.secondary is not None and prop.secondary not in currenttables: # TODO: this check is not strong enough for different paths to the same endpoint which @@ -458,6 +459,9 @@ class Query(object): mapper = prop.mapper + if mapper.select_table is not mapper.mapped_table: + adapt_against = mapper.select_table + if create_aliases: return (clause, mapper, alias) else: @@ -539,9 +543,9 @@ class Query(object): q = self._no_statement("order_by") - if self._aliases is not None: + if self._adapter: criterion = [expression._literal_as_text(o) for o in util.to_list(criterion) or []] - criterion = self._aliases.adapt_list(criterion) + criterion = self._adapter.copy_and_process(criterion) if q._order_by is False: q._order_by = util.to_list(criterion) @@ -568,9 +572,8 @@ class Query(object): if criterion is not None and not isinstance(criterion, sql.ClauseElement): raise exceptions.ArgumentError("having() argument must be of type sqlalchemy.sql.ClauseElement or string") - - if self._aliases is not None: - criterion = self._aliases.adapt_clause(criterion) + if self._adapter is not None: + criterion = self._adapter.traverse(criterion) q = self._no_statement("having") if q._having is not None: @@ -605,6 +608,13 @@ class Query(object): q._from_obj = clause q._joinpoint = mapper q._aliases = aliases + + if aliases: + q._adapter = sql_util.ClauseAdapter(aliases.alias).copy_and_chain(q._adapter) + else: + select_mapper = mapper.get_select_mapper() + if select_mapper._clause_adapter: + q._adapter = select_mapper._clause_adapter.copy_and_chain(q._adapter) a = aliases while a is not None: @@ -629,6 +639,8 @@ class Query(object): q = self._no_statement("reset_joinpoint") q._joinpoint = q.mapper q._aliases = None + if q.table not in q._get_joinable_tables(): + q._adapter = sql_util.ClauseAdapter(q._from_obj, equivalents=q.mapper._equivalent_columns) return q @@ -651,6 +663,9 @@ class Query(object): from_obj = from_obj.alias() new._from_obj = from_obj + + if new.table not in new._get_joinable_tables(): + new._adapter = sql_util.ClauseAdapter(new._from_obj, equivalents=new.mapper._equivalent_columns) return new def __getitem__(self, item): @@ -787,9 +802,9 @@ class Query(object): mappers_or_columns = tuple(self._entities) + mappers_or_columns tuples = bool(mappers_or_columns) - if self._primary_adapter: + if context.row_adapter: def main(context, row): - return self.select_mapper._instance(context, self._primary_adapter(row), None, + return self.select_mapper._instance(context, context.row_adapter(row), None, extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance ) else: @@ -957,17 +972,18 @@ class Query(object): from_obj = self._from_obj - # indicates if the "from" clause of the query does not include - # the normally mapped table, i.e. the user issued select_from(somestatement) - # or similar. all clauses which derive from the mapped table will need to - # be adapted to be relative to the user-supplied selectable. - adapt_criterion = self.table not in self._get_joinable_tables() - - # adapt for poylmorphic mapper - # TODO: generalize the polymorphic mapper adaption to that of the select_from() adaption - if not adapt_criterion and whereclause is not None and (self.mapper is not self.select_mapper): - whereclause = sql_util.ClauseAdapter(from_obj, equivalents=self.select_mapper._equivalent_columns).traverse(whereclause) + # if the query's ClauseAdapter is present, and its + # specifically adapting against a modified "select_from" + # argument, apply adaptation to the + # individually selected columns as well as "eager" clauses added; + # otherwise its currently not needed + if self._adapter and self.table not in self._get_joinable_tables(): + adapter = self._adapter + else: + adapter = None + adapter = self._adapter + # TODO: mappers added via add_entity(), adapt their queries also, # if those mappers are polymorphic @@ -1029,7 +1045,9 @@ class Query(object): for o in order_by: cf.update(sql_util.find_columns(o)) - if adapt_criterion: + if adapter: + # TODO: make usage of the ClauseAdapter here to create the list + # of primary columns context.primary_columns = [from_obj.corresponding_column(c) or c for c in context.primary_columns] cf = [from_obj.corresponding_column(c) or c for c in cf] @@ -1037,7 +1055,7 @@ class Query(object): s3 = s2.alias() - self._primary_adapter = mapperutil.create_row_adapter(s3, self.table) + context.row_adapter = mapperutil.create_row_adapter(s3, self.table) statement = sql.select([s3] + context.secondary_columns, for_update=for_update, use_labels=True) @@ -1050,17 +1068,16 @@ class Query(object): statement.append_order_by(*context.eager_order_by) else: - if adapt_criterion: + if adapter: + # TODO: make usage of the ClauseAdapter here to create row adapter, list + # of primary columns context.primary_columns = [from_obj.corresponding_column(c) or c for c in context.primary_columns] - self._primary_adapter = mapperutil.create_row_adapter(from_obj, self.table) + context.row_adapter = mapperutil.create_row_adapter(from_obj, self.table) - if adapt_criterion or self._distinct: + if self._distinct: if order_by: order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []] - if adapt_criterion: - order_by = sql_util.ClauseAdapter(from_obj).copy_and_process(order_by) - if self._distinct and order_by: cf = util.Set() for o in order_by: @@ -1071,13 +1088,13 @@ class Query(object): statement = sql.select(context.primary_columns + context.secondary_columns, whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, order_by=util.to_list(order_by), **self._select_args()) if context.eager_joins: - if adapt_criterion: - context.eager_joins = sql_util.ClauseAdapter(from_obj).traverse(context.eager_joins) + if adapter: + context.eager_joins = adapter.traverse(context.eager_joins) statement.append_from(context.eager_joins, _copy_collection=False) if context.eager_order_by: - if adapt_criterion: - context.eager_order_by = sql_util.ClauseAdapter(from_obj).copy_and_process(context.eager_order_by) + if adapter: + context.eager_order_by = adapter.copy_and_process(context.eager_order_by) statement.append_order_by(*context.eager_order_by) context.statement = statement @@ -1103,6 +1120,7 @@ class Query(object): return self._alias_ids[alias_id] except KeyError: raise exceptions.InvalidRequestError("Query has no alias identified by '%s'" % alias_id) + if isinstance(m, type): m = mapper.class_mapper(m) if isinstance(m, mapper.Mapper): @@ -1369,6 +1387,7 @@ class QueryContext(object): self.session = query.session self.extension = query._extension self.statement = None + self.row_adapter = None self.populate_existing = query._populate_existing self.version_check = query._version_check self.only_load_props = query._only_load_props diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index a715d924a1..908c43feb1 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -519,10 +519,7 @@ class EagerLoader(AbstractRelationLoader): if context.eager_joins: towrap = context.eager_joins else: - if isinstance(context.from_clause, sql.Join): - towrap = context.from_clause - else: - towrap = localparent.mapped_table + towrap = context.from_clause # create AliasedClauses object to build up the eager query. this is cached after 1st creation. try: diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 7473609d74..4f2ab5444a 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -236,10 +236,6 @@ class PropertyAliasedClauses(AliasedClauses): super(PropertyAliasedClauses, self).__init__(prop.select_table) self.parentclauses = parentclauses - if parentclauses is not None: - self.path = build_path(prop.parent, prop.key, parentclauses.path) - else: - self.path = build_path(prop.parent, prop.key) self.prop = prop @@ -261,6 +257,7 @@ class PropertyAliasedClauses(AliasedClauses): aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, exclude=prop.remote_side)) else: aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side) + self.primaryjoin = aliasizer.traverse(primaryjoin, clone=True) self.secondary = None self.secondaryjoin = None @@ -273,9 +270,6 @@ class PropertyAliasedClauses(AliasedClauses): mapper = property(lambda self:self.prop.mapper) table = property(lambda self:self.prop.select_table) - def __str__(self): - return "->".join([str(s) for s in self.path]) - def instance_str(instance): """Return a string describing an instance.""" diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index b45c0425c8..c2ac26557e 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -186,6 +186,25 @@ class ClauseAdapter(AbstractClauseProcessor): self.exclude = exclude self.equivalents = equivalents + def copy_and_chain(self, adapter): + """create a copy of this adapter and chain to the given adapter. + + currently this adapter must be unchained to start, raises + an exception if it's already chained. + + Does not modify the given adapter. + """ + + if adapter is None: + return self + + if hasattr(self, '_next_acp') or hasattr(self, '_next'): + raise NotImplementedError("Can't chain_to on an already chained ClauseAdapter (yet)") + + ca = ClauseAdapter(self.selectable, self.include, self.exclude, self.equivalents) + ca._next_acp = adapter + return ca + def convert_element(self, col): if isinstance(col, expression.FromClause): if self.selectable.is_derived_from(col): diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py index e42ef5cb81..f35fbcbfc2 100644 --- a/test/orm/eager_relations.py +++ b/test/orm/eager_relations.py @@ -195,6 +195,10 @@ class EagerTest(FixtureTest): assert fixtures.item_keyword_result[0:2] == q.join('keywords').filter(keywords.c.name == 'red').all() self.assert_sql_count(testing.db, go, 1) + def go(): + assert fixtures.item_keyword_result[0:2] == q.join('keywords', aliased=True).filter(keywords.c.name == 'red').all() + self.assert_sql_count(testing.db, go, 1) + def test_eager_option(self): mapper(Keyword, keywords) diff --git a/test/orm/inheritance/query.py b/test/orm/inheritance/query.py index 698df33fa7..2a15ae1b01 100644 --- a/test/orm/inheritance/query.py +++ b/test/orm/inheritance/query.py @@ -93,7 +93,7 @@ class PolymorphicQueryTest(ORMTest): mapper(Paperwork, paperwork) def insert_data(self): - global all_employees, c1_employees, c2_employees, e1, e2, b1, m1, e3 + global all_employees, c1_employees, c2_employees, e1, e2, b1, m1, e3, c1, c2 c1 = Company(name="MegaCorp, Inc.") c2 = Company(name="Elbonia, Inc.") @@ -114,7 +114,9 @@ class PolymorphicQueryTest(ORMTest): ]) c1.employees = [e1, e2, b1, m1] - e3 = Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer") + e3 = Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer", paperwork=[ + Paperwork(description='elbonian missive #3') + ]) c2.employees = [e3] sess = create_session() sess.save(c1) @@ -127,9 +129,6 @@ class PolymorphicQueryTest(ORMTest): c2_employees = [e3] def test_filter_on_subclass(self): - print Manager.person_id == Engineer.person_id - print Manager.c.person_id == Engineer.c.person_id - sess = create_session() self.assertEquals(sess.query(Engineer).all()[0], Engineer(name="dilbert")) @@ -142,12 +141,74 @@ class PolymorphicQueryTest(ORMTest): self.assertEquals(sess.query(Manager).filter(Manager.person_id==b1.person_id).one(), Boss(name="pointy haired boss")) self.assertEquals(sess.query(Boss).filter(Boss.person_id==b1.person_id).one(), Boss(name="pointy haired boss")) + + def test_join_from_polymorphic(self): + sess = create_session() - def test_load_all(self): + for aliased in (True, False): + self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1]) + + self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1]) + + self.assertEquals(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1]) + + self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) + + def test_join_to_polymorphic(self): + sess = create_session() + self.assertEquals(sess.query(Company).join('employees').filter(Person.name=='vlad').one(), c2) + + self.assertEquals(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2) + + def test_join_through_polymorphic(self): + sess = create_session() + + for aliased in (True, False): + self.assertEquals( + sess.query(Company).\ + join(['employees', 'paperwork'], aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), + [c1] + ) + + self.assertEquals( + sess.query(Company).\ + join(['employees', 'paperwork'], aliased=aliased).filter(Paperwork.description.like('%#%')).all(), + [c1, c2] + ) + + self.assertEquals( + sess.query(Company).\ + join(['employees', 'paperwork'], aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).filter(Paperwork.description.like('%#2%')).all(), + [c1] + ) + + self.assertEquals( + sess.query(Company).\ + join(['employees', 'paperwork'], aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).filter(Paperwork.description.like('%#%')).all(), + [c1, c2] + ) + + self.assertEquals( + sess.query(Company).join('employees', aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).\ + join('paperwork', from_joinpoint=True, aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), + [c1] + ) + + self.assertEquals( + sess.query(Company).join('employees', aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).\ + join('paperwork', from_joinpoint=True, aliased=aliased).filter(Paperwork.description.like('%#%')).all(), + [c1, c2] + ) + + def test_filter_on_baseclass(self): sess = create_session() self.assertEquals(sess.query(Person).all(), all_employees) + self.assertEquals(sess.query(Person).first(), all_employees[0]) + + self.assertEquals(sess.query(Person).filter(Person.person_id==e2.person_id).one(), e2) + if __name__ == "__main__": testenv.main() -- 2.47.3