author    Mike Bayer <mike_mp@zzzcomputing.com>
          Sat, 1 Mar 2008 01:46:23 +0000 (01:46 +0000)
committer Mike Bayer <mike_mp@zzzcomputing.com>
          Sat, 1 Mar 2008 01:46:23 +0000 (01:46 +0000)

- fixed bug whereby session.expire() attributes were not
  loading on a polymorphically-mapped instance mapped
  by a select_table mapper.
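
  A minimal sketch of the affected pattern, mirroring the new test_expire
  test added in test/orm/inheritance/query.py (the Manager mapping and data
  come from those test fixtures):

      sess = create_session()
      m1 = sess.query(Manager).filter(Manager.name == 'dogbert').one()
      sess.expire(m1)
      # attribute access now re-loads correctly through the select_table mapper
      assert m1.status == 'regular manager'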

- added query.with_polymorphic() - specifies a list
of classes which descend from the base class, which will
be added to the FROM clause of the query.  Allows subclasses
to be used within filter() criterion, and eagerly loads
the attributes of those subclasses.
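
  For example, a sketch based on the new with_polymorphic tests in
  test/orm/inheritance/query.py (Person and Engineer are the inheritance
  classes from those fixtures):

      # pull Engineer's table into the FROM clause; Engineer columns can be
      # used in filter() and are loaded without a post-fetch
      sess.query(Person).with_polymorphic(Engineer).\
          filter(Engineer.primary_language == 'java').all()

      # '*' includes every mapper descending from the base mapper
      sess.query(Person).with_polymorphic('*').all()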

- deprecated Query methods apply_sum(), apply_max(), apply_min(),
apply_avg().  Better methodologies are coming....
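
  For comparison, the deprecated form alongside the existing Query.sum()
  method exercised by the adjacent AggregateTest (shown only as a
  currently-available alternative, not necessarily the forthcoming
  replacement):

      # deprecated as of this commit; emits a deprecation warning
      sess.query(Order).apply_sum(Order.user_id * Order.address_id).\
          filter(Order.id.in_([2, 3, 4])).one()

      # same result via Query.sum(), as in test/orm/query.py
      sess.query(Order).filter(Order.id.in_([2, 3, 4])).\
          sum(Order.user_id * Order.address_id)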

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
test/orm/generative.py
test/orm/inheritance/query.py
test/orm/query.py

diff --git a/CHANGES b/CHANGES
index 501db5d286f1b8a2894b210d29afb21743c9959a..3e5d36c20fba6ddb765f3253d25806f09adcadd3 100644
--- a/CHANGES
+++ b/CHANGES
@@ -25,7 +25,17 @@ CHANGES
       work properly with self-referential relations - the clause
       inside the EXISTS is aliased on the "remote" side to
       distinguish it from the parent table.
-
+    
+    - fixed bug whereby session.expire() attributes were not
+      loading on a polymorphically-mapped instance mapped
+      by a select_table mapper.
+      
+    - added query.with_polymorphic() - specifies a list
+      of classes which descend from the base class, which will
+      be added to the FROM clause of the query.  Allows subclasses
+      to be used within filter() criterion, and eagerly loads
+      the attributes of those subclasses.
+    
     - Your cries have been heard: removing a pending item from an
       attribute or collection with delete-orphan expunges the item
       from the session; no FlushError is raised.  Note that if you
@@ -35,6 +45,9 @@ CHANGES
     - Fixed potential generative bug when the same Query was used
       to generate multiple Query objects using join().
 
+    - deprecated Query methods apply_sum(), apply_max(), apply_min(),
+      apply_avg().  Better methodologies are coming....
+      
     - Added a new "higher level" operator called "of_type()": used
       in join() as well as with any() and has(), qualifies the
       subclass which will be used in filter criterion, e.g.:
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 62067fc358abdfaae8662e2ed9cbd8531c7f033a..297d222466b334b420407fa7edbbefd405255799 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1579,6 +1579,7 @@ def _load_scalar_attributes(instance, attribute_names):
         identity_key = state.dict['_instance_key']
     else:
         identity_key = mapper._identity_key_from_state(state)
+
     if session.query(mapper)._get(identity_key, refresh_instance=state, only_load_props=attribute_names) is None:
         raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % instance_str(instance))
 
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 46f986d14e0b012d8deac9da3aa33cea89ab2d54..ebe62e915c4b73cce5c106b78fa4c337adf8c9b1 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -33,14 +33,12 @@ class Query(object):
     """Encapsulates the object-fetching operations provided by Mappers."""
 
     def __init__(self, class_or_mapper, session=None, entity_name=None):
-        self.mapper = _class_to_mapper(class_or_mapper, entity_name=entity_name)
-        self.select_mapper = self.mapper.get_select_mapper().compile()
-
+        self._init_mapper(_class_to_mapper(class_or_mapper, entity_name=entity_name))
         self._session = session
 
         self._with_options = []
         self._lockmode = None
-        self._extension = self.mapper.extension
+        
         self._entities = []
         self._order_by = False
         self._group_by = False
@@ -54,24 +52,41 @@ class Query(object):
         self._joinable_tables = None
         self._having = None
         self._column_aggregate = None
-        self._joinpoint = self.mapper
         self._aliases = None
         self._alias_ids = {}
-        self._from_obj = self.table
         self._populate_existing = False
         self._version_check = False
         self._autoflush = True
-        self._eager_loaders = util.Set(chain(*[mp._eager_loaders for mp in [m for m in self.mapper.iterate_to_root()]]))
+        
         self._attributes = {}
         self._current_path = ()
         self._only_load_props = None
         self._refresh_instance = None
-
+    
+    def _init_mapper(self, mapper, select_mapper=None):
+        """populate all instance variables derived from this Query's mapper."""
+        
+        self.mapper = mapper
+        self.select_mapper = select_mapper or self.mapper.get_select_mapper().compile()
+        self.table = self._from_obj = self.select_mapper.mapped_table
+        self._eager_loaders = util.Set(chain(*[mp._eager_loaders for mp in [m for m in self.mapper.iterate_to_root()]]))
+        self._extension = self.mapper.extension
         self._adapter = self.select_mapper._clause_adapter
-
+        self._joinpoint = self.mapper
+        self._with_polymorphic = []
+    
     def _no_criterion(self, meth):
-        q = self._clone()
+        return self._conditional_clone(meth, [self._no_criterion_condition])
 
+    def _no_statement(self, meth):
+        return self._conditional_clone(meth, [self._no_statement_condition])
+    
+    def _new_base_mapper(self, mapper, meth):
+        q = self._conditional_clone(meth, [self._no_criterion_condition])
+        q._init_mapper(mapper, mapper)
+        return q
+        
+    def _no_criterion_condition(self, q, meth):
         if q._criterion or q._statement or q._from_obj is not self.table:
             util.warn(
                 ("Query.%s() being called on a Query with existing criterion; "
@@ -83,16 +98,20 @@ class Query(object):
         q._joinpoint = self.mapper
         q._statement = q._aliases = q._criterion = None
         q._order_by = q._group_by = q._distinct = False
-        return q
-
-    def _no_statement(self, meth):
-        q = self._clone()
+    
+    def _no_statement_condition(self, q, meth):
         if q._statement:
             raise exceptions.InvalidRequestError(
                 ("Query.%s() being called on a Query with an existing full "
                  "statement - can't apply criterion.") % meth)
+    
+    def _conditional_clone(self, methname=None, conditions=None):
+        q = self._clone()
+        if conditions:
+            for condition in conditions:
+                condition(q, methname)
         return q
-
+        
     def _clone(self):
         q = Query.__new__(Query)
         q.__dict__ = self.__dict__.copy()
@@ -104,7 +123,6 @@ class Query(object):
         else:
             return self._session
 
-    table = property(lambda s:s.select_mapper.mapped_table)
     primary_key_columns = property(lambda s:s.select_mapper.primary_key)
     session = property(_get_session)
 
@@ -112,7 +130,63 @@ class Query(object):
         q = self._clone()
         q._current_path = path
         return q
+    
+    def with_polymorphic(self, cls_or_mappers, selectable=None):
+        """Load columns for descendant mappers of this Query's mapper.
+        
+        Using this method will ensure that each descendant mapper's
+        tables are included in the FROM clause, and will allow filter() 
+        criterion to be used against those tables.  The resulting 
+        instances will also have those columns already loaded so that
+        no "post fetch" of those columns will be required.
+        
+        If this Query's mapper has a ``select_table`` argument, 
+        with_polymorphic() overrides it; the FROM clause will be against
+        the local table of the base mapper outer joined with the local
+        tables of each specified descendant mapper (unless ``selectable``
+        is specified).
+        
+        ``cls_or_mappers`` is a single class or mapper, or list of class/mappers,
+        which inherit from this Query's mapper.  Alternatively, it
+        may also be the string ``'*'``, in which case all descending 
+        mappers will be added to the FROM clause.
+        
+        ``selectable`` is a table or select() statement that will 
+        be used in place of the generated FROM clause.  This argument
+        is required if any of the desired mappers use concrete table 
+        inheritance, since SQLAlchemy currently cannot generate UNIONs 
+        among tables automatically.  If used, the ``selectable`` 
+        argument must represent the full set of tables and columns mapped 
+        by every desired mapper.  Otherwise, the unaccounted mapped columns
+        will result in their table being appended directly to the FROM 
+        clause which will usually lead to incorrect results.
+
+        """
+        
+        q = self._new_base_mapper(self.mapper, 'with_polymorphic')
 
+        if cls_or_mappers == '*':
+            cls_or_mappers = self.mapper.polymorphic_iterator()
+        else:
+            cls_or_mappers = util.to_list(cls_or_mappers)
+        
+        if selectable:
+            q = q.select_from(selectable)
+                
+        for cls_or_mapper in cls_or_mappers:
+            poly_mapper = _class_to_mapper(cls_or_mapper)
+            if poly_mapper is self.mapper:
+                continue
+
+            q._with_polymorphic.append(poly_mapper)
+            if not selectable:
+                if poly_mapper.concrete:
+                    raise exceptions.InvalidRequestError("'with_polymorphic()' requires 'selectable' argument when concrete-inheriting mappers are used.")
+                elif not poly_mapper.single:
+                    q._from_obj = q._from_obj.outerjoin(poly_mapper.local_table, poly_mapper.inherit_condition)
+
+        return q
+        
     def yield_per(self, count):
         """Yield only ``count`` rows at a time.
 
@@ -412,6 +486,8 @@ class Query(object):
         # hand side.
         if self._adapter and not self._aliases:  # at the beginning of a join, look at leftmost adapter
             adapt_against = self._adapter.selectable
+        elif start is self.select_mapper:  # or if its our base mapper, go against our base table
+            adapt_against = self.table
         elif start.select_table is not start.mapped_table: # in the middle of a join, look for a polymorphic mapper
             adapt_against = start.select_table
         else:
@@ -444,7 +520,7 @@ class Query(object):
                     raise exceptions.InvalidRequestError("Selectable '%s' is not derived from '%s'" % (use_selectable.description, prop.mapper.mapped_table.description))
                 if not isinstance(use_selectable, expression.Alias):
                     use_selectable = use_selectable.alias()
-
+            
             if prop._is_self_referential() and not create_aliases and not use_selectable:
                 raise exceptions.InvalidRequestError("Self-referential query on '%s' property requires create_aliases=True argument." % str(prop))
 
@@ -503,24 +579,32 @@ class Query(object):
     def apply_min(self, col):
         """apply the SQL ``min()`` function against the given column to the
         query and return the newly resulting ``Query``.
+        
+        DEPRECATED.
         """
         return self._generative_col_aggregate(col, sql.func.min)
 
     def apply_max(self, col):
         """apply the SQL ``max()`` function against the given column to the
         query and return the newly resulting ``Query``.
+
+        DEPRECATED.
         """
         return self._generative_col_aggregate(col, sql.func.max)
 
     def apply_sum(self, col):
         """apply the SQL ``sum()`` function against the given column to the
         query and return the newly resulting ``Query``.
+
+        DEPRECATED.
         """
         return self._generative_col_aggregate(col, sql.func.sum)
 
     def apply_avg(self, col):
         """apply the SQL ``avg()`` function against the given column to the
         query and return the newly resulting ``Query``.
+
+        DEPRECATED.
         """
         return self._generative_col_aggregate(col, sql.func.avg)
 
@@ -852,6 +936,11 @@ class Query(object):
 
         context.runid = _new_runid()
 
+        # for with_polymorphic, instruct descendant mappers that they
+        # don't need to post-fetch anything
+        for m in self._with_polymorphic:
+            context.attributes[('polymorphic_fetch', m)] = (self.select_mapper, [])
+
         mappers_or_columns = tuple(self._entities) + mappers_or_columns
         tuples = bool(mappers_or_columns)
 
@@ -950,12 +1039,17 @@ class Query(object):
             ident = util.to_list(ident)
 
         q = self
+        
+        # dont use 'polymorphic' mapper if we are refreshing an instance
+        if refresh_instance and q.select_mapper is not q.mapper:
+            q = q._new_base_mapper(q.mapper, '_get')
+
         if ident is not None:
             q = q._no_criterion('get')
             params = {}
-            (_get_clause, _get_params) = self.select_mapper._get_clause
+            (_get_clause, _get_params) = q.select_mapper._get_clause
             q = q.filter(_get_clause)
-            for i, primary_key in enumerate(self.primary_key_columns):
+            for i, primary_key in enumerate(q.primary_key_columns):
                 try:
                     params[_get_params[primary_key].key] = ident[i]
                 except IndexError:
@@ -1027,25 +1121,10 @@ class Query(object):
             return context
 
         whereclause = self._criterion
-
         from_obj = self._from_obj
-
-        # if the query's ClauseAdapter is present, and its
-        # specifically adapting against a modified "select_from"
-        # argument, apply adaptation to the
-        # individually selected columns as well as "eager" clauses added;
-        # otherwise its currently not needed
-        if self._adapter and self.table not in self._get_joinable_tables():
-            adapter = self._adapter
-        else:
-            adapter = None
-
         adapter = self._adapter
-
-        # TODO: mappers added via add_entity(), adapt their queries also,
-        # if those mappers are polymorphic
-
         order_by = self._order_by
+
         if order_by is False:
             order_by = self.select_mapper.order_by
         if order_by is False:
@@ -1055,22 +1134,31 @@ class Query(object):
             if from_obj.default_order_by() is not None:
                 order_by = from_obj.default_order_by()
 
-        try:
-            for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[self._lockmode]
-        except KeyError:
-            raise exceptions.ArgumentError("Unknown lockmode '%s'" % self._lockmode)
-
+        if self._lockmode:
+            try:
+                for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[self._lockmode]
+            except KeyError:
+                raise exceptions.ArgumentError("Unknown lockmode '%s'" % self._lockmode)
+        else:
+            for_update = False
+            
         # if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so
         # that we only load the appropriate types
-        if self.select_mapper.single and self.select_mapper.polymorphic_on is not None and self.select_mapper.polymorphic_identity is not None:
+        if self.select_mapper.single and self.select_mapper.inherits is not None and self.select_mapper.polymorphic_on is not None and self.select_mapper.polymorphic_identity is not None:
             whereclause = sql.and_(whereclause, self.select_mapper.polymorphic_on.in_([m.polymorphic_identity for m in self.select_mapper.polymorphic_iterator()]))
 
         context.from_clause = from_obj
 
-        # give all the attached properties a chance to modify the query
-        # TODO: doing this off the select_mapper.  if its the polymorphic mapper, then
-        # it has no relations() on it.  should we compile those too into the query ?  (i.e. eagerloads)
-        for value in self.select_mapper.iterate_properties:
+        # TODO: compile eagerloads from select_mapper if polymorphic ? [ticket:917]
+        if self._with_polymorphic:
+            props = util.Set()
+            for m in [self.select_mapper]  + self._with_polymorphic:
+                for value in m.iterate_properties:
+                    props.add(value)
+        else:
+            props = self.select_mapper.iterate_properties
+            
+        for value in props:
             if self._only_load_props and value.key not in self._only_load_props:
                 continue
             context.exec_with_path(self.select_mapper, value.key, value.setup, context, only_load_props=self._only_load_props)
@@ -1091,12 +1179,9 @@ class Query(object):
             # eager loaders are present, and the SELECT has limiting criterion
             # produce a "wrapped" selectable.
 
-            # ensure all 'order by' elements are ClauseElement instances
-            # (since they will potentially be aliased)
             # locate all embedded Column clauses so they can be added to the
             # "inner" select statement where they'll be available to the enclosing
             # statement's "order by"
-
             cf = util.Set()
             if order_by:
                 order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []]
@@ -1105,7 +1190,7 @@ class Query(object):
 
             if adapter:
                 # TODO: make usage of the ClauseAdapter here to create the list
-                # of primary columns
+                # of primary columns ?
                 context.primary_columns = [from_obj.corresponding_column(c) or c for c in context.primary_columns]
                 cf = [from_obj.corresponding_column(c) or c for c in cf]
 
@@ -1128,7 +1213,7 @@ class Query(object):
         else:
             if adapter:
                 # TODO: make usage of the ClauseAdapter here to create row adapter, list
-                # of primary columns
+                # of primary columns ?
                 context.primary_columns = [from_obj.corresponding_column(c) or c for c in context.primary_columns]
                 context.row_adapter = mapperutil.create_row_adapter(from_obj, self.table)
 
@@ -1425,13 +1510,12 @@ class Query(object):
 
         return self._legacy_filter_by(*args, **params).one()
 
-
 for deprecated_method in ('list', 'scalar', 'count_by',
                           'select_whereclause', 'get_by', 'select_by',
                           'join_by', 'selectfirst', 'selectone', 'select',
                           'execute', 'select_statement', 'select_text',
                           'join_to', 'join_via', 'selectfirst_by',
-                          'selectone_by'):
+                          'selectone_by', 'apply_max', 'apply_min', 'apply_avg', 'apply_sum'):
     setattr(Query, deprecated_method,
             util.deprecated(getattr(Query, deprecated_method),
                             add_deprecation_to_docstring=False))
diff --git a/test/orm/generative.py b/test/orm/generative.py
index 9967f34f7e310972ed96714ca325052b6846a180..db8e313e6764d072febb326d943a8f1e40346a2a 100644
--- a/test/orm/generative.py
+++ b/test/orm/generative.py
@@ -53,6 +53,7 @@ class GenerativeQueryTest(TestBase):
         assert list(query[-5:]) == orig[-5:]
         assert query[10:20][5] == orig[10:20][5]
 
+    @testing.uses_deprecated('Call to deprecated function apply_max')
     def test_aggregate(self):
         sess = create_session(bind=testing.db)
         query = sess.query(Foo)
@@ -77,6 +78,7 @@ class GenerativeQueryTest(TestBase):
         assert round(avg, 1) == 14.5
 
     @testing.fails_on('firebird', 'mssql')
+    @testing.uses_deprecated('Call to deprecated function apply_avg')
     def test_aggregate_3(self):
         query = create_session(bind=testing.db).query(Foo)
 
diff --git a/test/orm/inheritance/query.py b/test/orm/inheritance/query.py
index 3571480292736b218a6340d5f743ca22e529c0d9..7d7b8b9d918f26acbc9c5bd24ef4dc0f9d62ce9d 100644
--- a/test/orm/inheritance/query.py
+++ b/test/orm/inheritance/query.py
@@ -194,6 +194,19 @@ def make_test(select_type):
                 self.assertEquals(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1])
 
                 self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
+
+        def test_join_from_with_polymorphic(self):
+            sess = create_session()
+
+            for aliased in (True, False):
+                sess.clear()
+                self.assertEquals(sess.query(Person).with_polymorphic(Manager).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1])
+
+                sess.clear()
+                self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1])
+
+                sess.clear()
+                self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
     
         def test_join_to_polymorphic(self):
             sess = create_session()
@@ -223,7 +236,58 @@ def make_test(select_type):
                     sess.query(Company).filter(Company.employees.any(and_(Engineer.primary_language=='cobol', people.c.person_id==engineers.c.person_id))).one(),
                     c2
                     )
-                
+        
+        def test_expire(self):
+            """test that individual column refresh doesn't get tripped up by the select_table mapper"""
+            
+            sess = create_session()
+            m1 = sess.query(Manager).filter(Manager.name=='dogbert').one()
+            sess.expire(m1)
+            assert m1.status == 'regular manager'
+
+            m2 = sess.query(Manager).filter(Manager.name=='pointy haired boss').one()
+            sess.expire(m2, ['manager_name', 'golf_swing'])
+            assert m2.golf_swing=='fore'
+            
+        def test_with_polymorphic(self):
+            
+            sess = create_session()
+            
+            # compare to entities without related collections to prevent additional lazy SQL from firing on 
+            # loaded entities
+            emps_without_relations = [
+                Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer"),
+                Engineer(name="wally", engineer_name="wally", primary_language="c++", status="regular engineer"),
+                Boss(name="pointy haired boss", golf_swing="fore", manager_name="pointy", status="da boss"),
+                Manager(name="dogbert", manager_name="dogbert", status="regular manager"),
+                Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer")
+            ]
+            
+            def go():
+                self.assertEquals(sess.query(Person).with_polymorphic(Engineer).filter(Engineer.primary_language=='java').all(), emps_without_relations[0:1])
+            self.assert_sql_count(testing.db, go, 1)
+            
+            sess.clear()
+            def go():
+                self.assertEquals(sess.query(Person).with_polymorphic('*').all(), emps_without_relations)
+            self.assert_sql_count(testing.db, go, 1)
+
+            sess.clear()
+            def go():
+                self.assertEquals(sess.query(Person).with_polymorphic(Engineer).all(), emps_without_relations)
+            self.assert_sql_count(testing.db, go, 3)
+
+            sess.clear()
+            def go():
+                self.assertEquals(sess.query(Person).with_polymorphic(Engineer, people.outerjoin(engineers)).all(), emps_without_relations)
+            self.assert_sql_count(testing.db, go, 3)
+            
+            sess.clear()
+            def go():
+                # limit the polymorphic join down to just "Person", overriding select_table
+                self.assertEquals(sess.query(Person).with_polymorphic(Person).all(), emps_without_relations)
+            self.assert_sql_count(testing.db, go, 6)
+
         def test_join_to_subclass(self):
             sess = create_session()
 
diff --git a/test/orm/query.py b/test/orm/query.py
index 41ae4446143e0d77adfd7fcc8e534b05d5839778..62bb99a3239a1496221edd05bd729d84b13d9887 100644
--- a/test/orm/query.py
+++ b/test/orm/query.py
@@ -389,6 +389,7 @@ class AggregateTest(QueryTest):
         orders = sess.query(Order).filter(Order.id.in_([2, 3, 4]))
         assert orders.sum(Order.user_id * Order.address_id) == 79
 
+    @testing.uses_deprecated('Call to deprecated function apply_sum')
     def test_apply(self):
         sess = create_session()
         assert sess.query(Order).apply_sum(Order.user_id * Order.address_id).filter(Order.id.in_([2, 3, 4])).one() == 79