]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Query.count() will take single-table inheritance
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 4 Jun 2008 19:25:35 +0000 (19:25 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 4 Jun 2008 19:25:35 +0000 (19:25 +0000)
subtypes into account the same way row-based
results do. (ticket:1008]. partial merge of 0.5's r4831.)

CHANGES
lib/sqlalchemy/orm/query.py
test/orm/inheritance/single.py

diff --git a/CHANGES b/CHANGES
index 3a9e678098dc1989218bf68862894f0d87190a38..ec85febbb518506ef37379bdef337c8899bf321f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -11,6 +11,10 @@ CHANGES
 
     - fixed bug preventing merge() from functioning in 
       conjunction with a comparable_property()
+
+    - Query.count() will take single-table inheritance
+      subtypes into account the same way row-based
+      results do. [ticket:1008]
       
 - mysql
     - Added 'CALL' to the list of SQL keywords which return
index 8996a758e6c2954490fa7d76e10ff108762fc130..6fffbddb8c91f8429c176c3ce720bb16b59e3544 100644 (file)
@@ -1084,9 +1084,12 @@ class Query(object):
         return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self.mapper.primary_key))
 
     def _col_aggregate(self, col, func, nested_cols=None):
-        whereclause = self._criterion
         
         context = QueryContext(self)
+        context.whereclause = self._criterion
+        self._adjust_for_single_inheritance(context)
+        whereclause = context.whereclause 
+        
         from_obj = self._from_obj
 
         if self._should_nest_selectable:
@@ -1130,6 +1133,8 @@ class Query(object):
         context.from_clause = from_obj
         context.whereclause = self._criterion
         context.order_by = self._order_by
+
+        self._adjust_for_single_inheritance(context)
         
         for entity in self._entities:
             entity.setup_context(self, context)
@@ -1207,6 +1212,22 @@ class Query(object):
 
         return context
 
+    def _adjust_for_single_inheritance(self, context):
+        """Apply single-table-inheritance filtering.
+        
+        For the base mapper of this query, add criterion to the WHERE clause of the given QueryContext
+        such that only the appropriate subtypes are selected from the total results.
+
+        A more sophisticated version of this, which works with multiple mappers and column expressions,
+        is present in 0.5.
+
+        """
+        # if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so
+        # that we only load the appropriate types
+        if self.mapper.single and self.mapper.inherits is not None and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None:
+            context.whereclause = sql.and_(context.whereclause, self.mapper.polymorphic_on.in_([m.polymorphic_identity for m in self.mapper.polymorphic_iterator()]))
+
+
     def __log_debug(self, msg):
         self.logger.debug(msg)
 
@@ -1554,10 +1575,6 @@ class _PrimaryMapperEntity(_MapperEntity):
         return main
 
     def setup_context(self, query, context):
-        # if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so
-        # that we only load the appropriate types
-        if self.mapper.single and self.mapper.inherits is not None and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None:
-            context.whereclause = sql.and_(context.whereclause, self.mapper.polymorphic_on.in_([m.polymorphic_identity for m in self.mapper.polymorphic_iterator()]))
         
         if context.order_by is False:
             if self.mapper.order_by:
index 2241afb0f8d9cbd6781a7e045fd7b33342b94de1..631dec24e5795a9bb3aeb4054a5e243bf53aaaf0 100644 (file)
@@ -44,6 +44,8 @@ class SingleInheritanceTest(ORMTest):
         assert session.query(Engineer).all() == [e1, e2]
         assert session.query(Manager).all() == [m1]
         assert session.query(JuniorEngineer).all() == [e2]
+        
+        assert session.query(Engineer).count() == 2
 
 class SingleOnJoinedTest(ORMTest):
     def define_tables(self, metadata):