From 075eb9076b7bf61351f4ee0465babe4e90e57d20 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 1 Mar 2008 01:46:23 +0000 Subject: [PATCH] - fixed bug whereby session.expire() attributes were not loading on an polymorphically-mapped instance mapped by a select_table mapper. - added query.with_polymorphic() - specifies a list of classes which descend from the base class, which will be added to the FROM clause of the query. Allows subclasses to be used within filter() criterion as well as eagerly loads the attributes of those subclasses. - deprecated Query methods apply_sum(), apply_max(), apply_min(), apply_avg(). Better methodologies are coming.... --- CHANGES | 15 ++- lib/sqlalchemy/orm/mapper.py | 1 + lib/sqlalchemy/orm/query.py | 188 ++++++++++++++++++++++++---------- test/orm/generative.py | 2 + test/orm/inheritance/query.py | 66 +++++++++++- test/orm/query.py | 1 + 6 files changed, 219 insertions(+), 54 deletions(-) diff --git a/CHANGES b/CHANGES index 501db5d286..3e5d36c20f 100644 --- a/CHANGES +++ b/CHANGES @@ -25,7 +25,17 @@ CHANGES work properly with self-referential relations - the clause inside the EXISTS is aliased on the "remote" side to distinguish it from the parent table. - + + - fixed bug whereby session.expire() attributes were not + loading on an polymorphically-mapped instance mapped + by a select_table mapper. + + - added query.with_polymorphic() - specifies a list + of classes which descend from the base class, which will + be added to the FROM clause of the query. Allows subclasses + to be used within filter() criterion as well as eagerly loads + the attributes of those subclasses. + - Your cries have been heard: removing a pending item from an attribute or collection with delete-orphan expunges the item from the session; no FlushError is raised. Note that if you @@ -35,6 +45,9 @@ CHANGES - Fixed potential generative bug when the same Query was used to generate multiple Query objects using join(). + - deprecated Query methods apply_sum(), apply_max(), apply_min(), + apply_avg(). Better methodologies are coming.... + - Added a new "higher level" operator called "of_type()": used in join() as well as with any() and has(), qualifies the subclass which will be used in filter criterion, e.g.: diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 62067fc358..297d222466 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1579,6 +1579,7 @@ def _load_scalar_attributes(instance, attribute_names): identity_key = state.dict['_instance_key'] else: identity_key = mapper._identity_key_from_state(state) + if session.query(mapper)._get(identity_key, refresh_instance=state, only_load_props=attribute_names) is None: raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % instance_str(instance)) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 46f986d14e..ebe62e915c 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -33,14 +33,12 @@ class Query(object): """Encapsulates the object-fetching operations provided by Mappers.""" def __init__(self, class_or_mapper, session=None, entity_name=None): - self.mapper = _class_to_mapper(class_or_mapper, entity_name=entity_name) - self.select_mapper = self.mapper.get_select_mapper().compile() - + self._init_mapper(_class_to_mapper(class_or_mapper, entity_name=entity_name)) self._session = session self._with_options = [] self._lockmode = None - self._extension = self.mapper.extension + self._entities = [] self._order_by = False self._group_by = False @@ -54,24 +52,41 @@ class Query(object): self._joinable_tables = None self._having = None self._column_aggregate = None - self._joinpoint = self.mapper self._aliases = None self._alias_ids = {} - self._from_obj = self.table self._populate_existing = False self._version_check = False self._autoflush = True - self._eager_loaders = util.Set(chain(*[mp._eager_loaders for mp in [m for m in self.mapper.iterate_to_root()]])) + self._attributes = {} self._current_path = () self._only_load_props = None self._refresh_instance = None - + + def _init_mapper(self, mapper, select_mapper=None): + """populate all instance variables derived from this Query's mapper.""" + + self.mapper = mapper + self.select_mapper = select_mapper or self.mapper.get_select_mapper().compile() + self.table = self._from_obj = self.select_mapper.mapped_table + self._eager_loaders = util.Set(chain(*[mp._eager_loaders for mp in [m for m in self.mapper.iterate_to_root()]])) + self._extension = self.mapper.extension self._adapter = self.select_mapper._clause_adapter - + self._joinpoint = self.mapper + self._with_polymorphic = [] + def _no_criterion(self, meth): - q = self._clone() + return self._conditional_clone(meth, [self._no_criterion_condition]) + def _no_statement(self, meth): + return self._conditional_clone(meth, [self._no_statement_condition]) + + def _new_base_mapper(self, mapper, meth): + q = self._conditional_clone(meth, [self._no_criterion_condition]) + q._init_mapper(mapper, mapper) + return q + + def _no_criterion_condition(self, q, meth): if q._criterion or q._statement or q._from_obj is not self.table: util.warn( ("Query.%s() being called on a Query with existing criterion; " @@ -83,16 +98,20 @@ class Query(object): q._joinpoint = self.mapper q._statement = q._aliases = q._criterion = None q._order_by = q._group_by = q._distinct = False - return q - - def _no_statement(self, meth): - q = self._clone() + + def _no_statement_condition(self, q, meth): if q._statement: raise exceptions.InvalidRequestError( ("Query.%s() being called on a Query with an existing full " "statement - can't apply criterion.") % meth) + + def _conditional_clone(self, methname=None, conditions=None): + q = self._clone() + if conditions: + for condition in conditions: + condition(q, methname) return q - + def _clone(self): q = Query.__new__(Query) q.__dict__ = self.__dict__.copy() @@ -104,7 +123,6 @@ class Query(object): else: return self._session - table = property(lambda s:s.select_mapper.mapped_table) primary_key_columns = property(lambda s:s.select_mapper.primary_key) session = property(_get_session) @@ -112,7 +130,63 @@ class Query(object): q = self._clone() q._current_path = path return q + + def with_polymorphic(self, cls_or_mappers, selectable=None): + """Load columns for descendant mappers of this Query's mapper. + + Using this method will ensure that each descendant mapper's + tables are included in the FROM clause, and will allow filter() + criterion to be used against those tables. The resulting + instances will also have those columns already loaded so that + no "post fetch" of those columns will be required. + + If this Query's mapper has a ``select_table`` argument, + with_polymorphic() overrides it; the FROM clause will be against + the local table of the base mapper outer joined with the local + tables of each specified descendant mapper (unless ``selectable`` + is specified). + + ``cls_or_mappers`` is a single class or mapper, or list of class/mappers, + which inherit from this Query's mapper. Alternatively, it + may also be the string ``'*'``, in which case all descending + mappers will be added to the FROM clause. + + ``selectable`` is a table or select() statement that will + be used in place of the generated FROM clause. This argument + is required if any of the desired mappers use concrete table + inheritance, since SQLAlchemy currently cannot generate UNIONs + among tables automatically. If used, the ``selectable`` + argument must represent the full set of tables and columns mapped + by every desired mapper. Otherwise, the unaccounted mapped columns + will result in their table being appended directly to the FROM + clause which will usually lead to incorrect results. + + """ + + q = self._new_base_mapper(self.mapper, 'with_polymorphic') + if cls_or_mappers == '*': + cls_or_mappers = self.mapper.polymorphic_iterator() + else: + cls_or_mappers = util.to_list(cls_or_mappers) + + if selectable: + q = q.select_from(selectable) + + for cls_or_mapper in cls_or_mappers: + poly_mapper = _class_to_mapper(cls_or_mapper) + if poly_mapper is self.mapper: + continue + + q._with_polymorphic.append(poly_mapper) + if not selectable: + if poly_mapper.concrete: + raise exceptions.InvalidRequestError("'with_polymorphic()' requires 'selectable' argument when concrete-inheriting mappers are used.") + elif not poly_mapper.single: + q._from_obj = q._from_obj.outerjoin(poly_mapper.local_table, poly_mapper.inherit_condition) + + return q + def yield_per(self, count): """Yield only ``count`` rows at a time. @@ -412,6 +486,8 @@ class Query(object): # hand side. if self._adapter and not self._aliases: # at the beginning of a join, look at leftmost adapter adapt_against = self._adapter.selectable + elif start is self.select_mapper: # or if its our base mapper, go against our base table + adapt_against = self.table elif start.select_table is not start.mapped_table: # in the middle of a join, look for a polymorphic mapper adapt_against = start.select_table else: @@ -444,7 +520,7 @@ class Query(object): raise exceptions.InvalidRequestError("Selectable '%s' is not derived from '%s'" % (use_selectable.description, prop.mapper.mapped_table.description)) if not isinstance(use_selectable, expression.Alias): use_selectable = use_selectable.alias() - + if prop._is_self_referential() and not create_aliases and not use_selectable: raise exceptions.InvalidRequestError("Self-referential query on '%s' property requires create_aliases=True argument." % str(prop)) @@ -503,24 +579,32 @@ class Query(object): def apply_min(self, col): """apply the SQL ``min()`` function against the given column to the query and return the newly resulting ``Query``. + + DEPRECATED. """ return self._generative_col_aggregate(col, sql.func.min) def apply_max(self, col): """apply the SQL ``max()`` function against the given column to the query and return the newly resulting ``Query``. + + DEPRECATED. """ return self._generative_col_aggregate(col, sql.func.max) def apply_sum(self, col): """apply the SQL ``sum()`` function against the given column to the query and return the newly resulting ``Query``. + + DEPRECATED. """ return self._generative_col_aggregate(col, sql.func.sum) def apply_avg(self, col): """apply the SQL ``avg()`` function against the given column to the query and return the newly resulting ``Query``. + + DEPRECATED. """ return self._generative_col_aggregate(col, sql.func.avg) @@ -852,6 +936,11 @@ class Query(object): context.runid = _new_runid() + # for with_polymorphic, instruct descendant mappers that they + # don't need to post-fetch anything + for m in self._with_polymorphic: + context.attributes[('polymorphic_fetch', m)] = (self.select_mapper, []) + mappers_or_columns = tuple(self._entities) + mappers_or_columns tuples = bool(mappers_or_columns) @@ -950,12 +1039,17 @@ class Query(object): ident = util.to_list(ident) q = self + + # dont use 'polymorphic' mapper if we are refreshing an instance + if refresh_instance and q.select_mapper is not q.mapper: + q = q._new_base_mapper(q.mapper, '_get') + if ident is not None: q = q._no_criterion('get') params = {} - (_get_clause, _get_params) = self.select_mapper._get_clause + (_get_clause, _get_params) = q.select_mapper._get_clause q = q.filter(_get_clause) - for i, primary_key in enumerate(self.primary_key_columns): + for i, primary_key in enumerate(q.primary_key_columns): try: params[_get_params[primary_key].key] = ident[i] except IndexError: @@ -1027,25 +1121,10 @@ class Query(object): return context whereclause = self._criterion - from_obj = self._from_obj - - # if the query's ClauseAdapter is present, and its - # specifically adapting against a modified "select_from" - # argument, apply adaptation to the - # individually selected columns as well as "eager" clauses added; - # otherwise its currently not needed - if self._adapter and self.table not in self._get_joinable_tables(): - adapter = self._adapter - else: - adapter = None - adapter = self._adapter - - # TODO: mappers added via add_entity(), adapt their queries also, - # if those mappers are polymorphic - order_by = self._order_by + if order_by is False: order_by = self.select_mapper.order_by if order_by is False: @@ -1055,22 +1134,31 @@ class Query(object): if from_obj.default_order_by() is not None: order_by = from_obj.default_order_by() - try: - for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[self._lockmode] - except KeyError: - raise exceptions.ArgumentError("Unknown lockmode '%s'" % self._lockmode) - + if self._lockmode: + try: + for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[self._lockmode] + except KeyError: + raise exceptions.ArgumentError("Unknown lockmode '%s'" % self._lockmode) + else: + for_update = False + # if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so # that we only load the appropriate types - if self.select_mapper.single and self.select_mapper.polymorphic_on is not None and self.select_mapper.polymorphic_identity is not None: + if self.select_mapper.single and self.select_mapper.inherits is not None and self.select_mapper.polymorphic_on is not None and self.select_mapper.polymorphic_identity is not None: whereclause = sql.and_(whereclause, self.select_mapper.polymorphic_on.in_([m.polymorphic_identity for m in self.select_mapper.polymorphic_iterator()])) context.from_clause = from_obj - # give all the attached properties a chance to modify the query - # TODO: doing this off the select_mapper. if its the polymorphic mapper, then - # it has no relations() on it. should we compile those too into the query ? (i.e. eagerloads) - for value in self.select_mapper.iterate_properties: + # TODO: compile eagerloads from select_mapper if polymorphic ? [ticket:917] + if self._with_polymorphic: + props = util.Set() + for m in [self.select_mapper] + self._with_polymorphic: + for value in m.iterate_properties: + props.add(value) + else: + props = self.select_mapper.iterate_properties + + for value in props: if self._only_load_props and value.key not in self._only_load_props: continue context.exec_with_path(self.select_mapper, value.key, value.setup, context, only_load_props=self._only_load_props) @@ -1091,12 +1179,9 @@ class Query(object): # eager loaders are present, and the SELECT has limiting criterion # produce a "wrapped" selectable. - # ensure all 'order by' elements are ClauseElement instances - # (since they will potentially be aliased) # locate all embedded Column clauses so they can be added to the # "inner" select statement where they'll be available to the enclosing # statement's "order by" - cf = util.Set() if order_by: order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []] @@ -1105,7 +1190,7 @@ class Query(object): if adapter: # TODO: make usage of the ClauseAdapter here to create the list - # of primary columns + # of primary columns ? context.primary_columns = [from_obj.corresponding_column(c) or c for c in context.primary_columns] cf = [from_obj.corresponding_column(c) or c for c in cf] @@ -1128,7 +1213,7 @@ class Query(object): else: if adapter: # TODO: make usage of the ClauseAdapter here to create row adapter, list - # of primary columns + # of primary columns ? context.primary_columns = [from_obj.corresponding_column(c) or c for c in context.primary_columns] context.row_adapter = mapperutil.create_row_adapter(from_obj, self.table) @@ -1425,13 +1510,12 @@ class Query(object): return self._legacy_filter_by(*args, **params).one() - for deprecated_method in ('list', 'scalar', 'count_by', 'select_whereclause', 'get_by', 'select_by', 'join_by', 'selectfirst', 'selectone', 'select', 'execute', 'select_statement', 'select_text', 'join_to', 'join_via', 'selectfirst_by', - 'selectone_by'): + 'selectone_by', 'apply_max', 'apply_min', 'apply_avg', 'apply_sum'): setattr(Query, deprecated_method, util.deprecated(getattr(Query, deprecated_method), add_deprecation_to_docstring=False)) diff --git a/test/orm/generative.py b/test/orm/generative.py index 9967f34f7e..db8e313e67 100644 --- a/test/orm/generative.py +++ b/test/orm/generative.py @@ -53,6 +53,7 @@ class GenerativeQueryTest(TestBase): assert list(query[-5:]) == orig[-5:] assert query[10:20][5] == orig[10:20][5] + @testing.uses_deprecated('Call to deprecated function apply_max') def test_aggregate(self): sess = create_session(bind=testing.db) query = sess.query(Foo) @@ -77,6 +78,7 @@ class GenerativeQueryTest(TestBase): assert round(avg, 1) == 14.5 @testing.fails_on('firebird', 'mssql') + @testing.uses_deprecated('Call to deprecated function apply_avg') def test_aggregate_3(self): query = create_session(bind=testing.db).query(Foo) diff --git a/test/orm/inheritance/query.py b/test/orm/inheritance/query.py index 3571480292..7d7b8b9d91 100644 --- a/test/orm/inheritance/query.py +++ b/test/orm/inheritance/query.py @@ -194,6 +194,19 @@ def make_test(select_type): self.assertEquals(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1]) self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) + + def test_join_from_with_polymorphic(self): + sess = create_session() + + for aliased in (True, False): + sess.clear() + self.assertEquals(sess.query(Person).with_polymorphic(Manager).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1]) + + sess.clear() + self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1]) + + sess.clear() + self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) def test_join_to_polymorphic(self): sess = create_session() @@ -223,7 +236,58 @@ def make_test(select_type): sess.query(Company).filter(Company.employees.any(and_(Engineer.primary_language=='cobol', people.c.person_id==engineers.c.person_id))).one(), c2 ) - + + def test_expire(self): + """test that individual column refresh doesn't get tripped up by the select_table mapper""" + + sess = create_session() + m1 = sess.query(Manager).filter(Manager.name=='dogbert').one() + sess.expire(m1) + assert m1.status == 'regular manager' + + m2 = sess.query(Manager).filter(Manager.name=='pointy haired boss').one() + sess.expire(m2, ['manager_name', 'golf_swing']) + assert m2.golf_swing=='fore' + + def test_with_polymorphic(self): + + sess = create_session() + + # compare to entities without related collections to prevent additional lazy SQL from firing on + # loaded entities + emps_without_relations = [ + Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer"), + Engineer(name="wally", engineer_name="wally", primary_language="c++", status="regular engineer"), + Boss(name="pointy haired boss", golf_swing="fore", manager_name="pointy", status="da boss"), + Manager(name="dogbert", manager_name="dogbert", status="regular manager"), + Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer") + ] + + def go(): + self.assertEquals(sess.query(Person).with_polymorphic(Engineer).filter(Engineer.primary_language=='java').all(), emps_without_relations[0:1]) + self.assert_sql_count(testing.db, go, 1) + + sess.clear() + def go(): + self.assertEquals(sess.query(Person).with_polymorphic('*').all(), emps_without_relations) + self.assert_sql_count(testing.db, go, 1) + + sess.clear() + def go(): + self.assertEquals(sess.query(Person).with_polymorphic(Engineer).all(), emps_without_relations) + self.assert_sql_count(testing.db, go, 3) + + sess.clear() + def go(): + self.assertEquals(sess.query(Person).with_polymorphic(Engineer, people.outerjoin(engineers)).all(), emps_without_relations) + self.assert_sql_count(testing.db, go, 3) + + sess.clear() + def go(): + # limit the polymorphic join down to just "Person", overriding select_table + self.assertEquals(sess.query(Person).with_polymorphic(Person).all(), emps_without_relations) + self.assert_sql_count(testing.db, go, 6) + def test_join_to_subclass(self): sess = create_session() diff --git a/test/orm/query.py b/test/orm/query.py index 41ae444614..62bb99a323 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -389,6 +389,7 @@ class AggregateTest(QueryTest): orders = sess.query(Order).filter(Order.id.in_([2, 3, 4])) assert orders.sum(Order.user_id * Order.address_id) == 79 + @testing.uses_deprecated('Call to deprecated function apply_sum') def test_apply(self): sess = create_session() assert sess.query(Order).apply_sum(Order.user_id * Order.address_id).filter(Order.id.in_([2, 3, 4])).one() == 79 -- 2.47.3