]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added a new "higher level" operator called "of_type()" -
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 21 Feb 2008 01:01:24 +0000 (01:01 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 21 Feb 2008 01:01:24 +0000 (01:01 +0000)
used in join() as well as with any() and has(), qualifies
the subclass which will be used in filter criterion,
e.g.:

query.filter(Company.employees.of_type(Engineer).
  any(Engineer.name=='foo')),

query.join(Company.employees.of_type(Engineer)).
  filter(Engineer.name=='foo')

CHANGES
doc/build/content/mappers.txt
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
test/orm/inheritance/query.py

diff --git a/CHANGES b/CHANGES
index 101b7e6e32f662d549c3dd9b23e2dc0987f4bf41..988c0403d1ff3ff49e5c5c3f8da2402cda04ca9e 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -20,6 +20,16 @@ CHANGES
       it out.
     - fixed potential generative bug when the same Query was
       used to generate multiple Query objects using join().
+    - added a new "higher level" operator called "of_type()" - 
+      used in join() as well as with any() and has(), qualifies
+      the subclass which will be used in filter criterion, 
+      e.g.: 
+      
+        query.filter(Company.employees.of_type(Engineer).
+          any(Engineer.name=='foo')), 
+      
+        query.join(Company.employees.of_type(Engineer)).
+          filter(Engineer.name=='foo')
       
 0.4.4
 ------
index 41cae69666de62c9fc76955dacda1a552eed3b5f..a842bc3df5d4aba6c4c4cecb5cb6035f9bd0efc1 100644 (file)
@@ -339,6 +339,65 @@ We then configure mappers as usual, except we use some additional arguments to i
 
 And that's it.  Querying against `Employee` will return a combination of `Employee`, `Engineer` and `Manager` objects.
 
+###### Polymorphic Querying Strategies {@name=querying}
+
+The `Query` object includes some helper functionality when dealing with joined-table inheritance mappings.  These helpers apply mostly to the `join()` method, as well as the special `any()` and `has()` operators.
+
+Suppose the `employees` table represents a collection of employees which are associated with a `Company` object.  We'll add a `company_id` column to the `employees` table and a new table `companies`:
+
+    {python}
+    companies = Table('companies', metadata,
+       Column('company_id', Integer, primary_key=True),
+       Column('name', String(50))
+       )
+
+    employees = Table('employees', metadata, 
+      Column('employee_id', Integer, primary_key=True),
+      Column('name', String(50)),
+      Column('type', String(30), nullable=False),
+      Column('company_id', Integer, ForeignKey('companies.company_id'))
+    )
+    
+    class Company(object):
+        pass
+    
+    mapper(Company, companies, properties={
+        'employees':relation(Employee)
+    })
+    
+If we wanted to join from `Company` to not just `Employee` but specifically `Engineers`, using the `join()` method or `any()` or `has()` operators will by default create a join from `companies` to `employees`, without including `engineers` or `managers` in the mix.  If we wish to have criterion which is specifically against the `Engineer` class, extra instruction is needed.  As of version 0.4.4 we can use this notation:
+
+    {python}
+    session.query(Company).join(Company.employees.of_type(Engineer)).filter(Engineer.engineer_info=='someinfo')
+    
+A longhand notation, introduced in 0.4.3, is also available, which involves spelling out the full target selectable within a 2-tuple:
+
+    {python}
+    session.query(Company).join(('employees', employees.join(engineers))).filter(Engineer.engineer_info=='someinfo')
+    
+The second notation allows more flexibility, such as joining to any group of subclass tables:
+
+    {python}
+    session.query(Company).join(('employees', employees.outerjoin(engineers).outerjoin(managers))).\
+        filter(or_(Engineer.engineer_info=='someinfo', Manager.manager_data=='somedata'))
+
+The `any()` and `has()` operators also can be used with `of_type()` when the embedded criterion is in terms of a subclass:
+
+    {python}
+    session.query(Company).filter(Company.employees.of_type(Engineer).any(Engineer.engineer_info=='someinfo')).all()
+    
+Note that these two operators are shorthand for a correlated EXISTS query.  To build one by hand looks like:
+
+    {python}
+    session.query(Company).filter(
+        exists([1], 
+            and_(Engineer.engineer_info=='someinfo', employees.c.company_id==companies.c.company_id), 
+            from_obj=employees.join(engineers)
+        )
+    ).all()
+
+Where the EXISTS query selects from the join of `employees` to `engineers`, and also specifies criterion which correlates the exists subselect back to the parent `companies` table.
+    
 ###### Optimizing Joined Table Loads {@name=optimizing}
 
 When loading fresh from the database, the joined-table setup above will query from the parent table first, then for each row will issue a second query to the child table.  For example, for a load of five rows with `Employee` id 3, `Manager` ids 1 and 5 and `Engineer` ids 2 and 4, will produce queries along the lines of this example:
index 2c3a98c88fe8c976a5c65268457e8b322889dda6..f510a3ffaf6b285fffeb1441e7e7e38d6fd567ab 100644 (file)
@@ -435,8 +435,29 @@ class PropComparator(expression.ColumnOperators):
     has_op = staticmethod(has_op)
 
     def __init__(self, prop):
-        self.prop = prop
-
+        self.prop = self.property = prop
+    
+    def of_type_op(a, class_):
+        return a.of_type(class_)
+    of_type_op = staticmethod(of_type_op)
+    
+    def of_type(self, class_):
+        """Redefine this object in terms of a polymorphic subclass.
+        
+        Returns a new PropComparator from which further criterion can be evaulated.
+        
+        class_
+          a class or mapper indicating that criterion will be against
+          this specific subclass.
+        
+        e.g.::
+          query.join(Company.employees.of_type(Engineer)).\
+             filter(Engineer.name=='foo')
+         
+        """
+        
+        return self.operate(PropComparator.of_type_op, class_)
+        
     def contains(self, other):
         """Return true if this collection contains other"""
         return self.operate(PropComparator.contains_op, other)
index 85aec2f4476fc9a4011361d7483ba339b88e4e81..c0a15c427e46082d9f1c254fe01051e7e371efb6 100644 (file)
@@ -1628,3 +1628,12 @@ def class_mapper(class_, entity_name=None, compile=True, raiseerror=True):
         return mapper.compile()
     else:
         return mapper
+
+def _class_to_mapper(class_or_mapper, entity_name=None, compile=True):
+    if isinstance(class_or_mapper, type):
+        return class_mapper(class_or_mapper, entity_name=entity_name, compile=compile)
+    else:
+        if compile:
+            return class_or_mapper.compile()
+        else:
+            return class_or_mapper
index 6339ec5750ae28784c72e223e95b57115fd5f750..74d4c04ca45bf62d821f2069c736f8336ae1f970 100644 (file)
@@ -15,6 +15,7 @@ from sqlalchemy.sql.util import ClauseAdapter, ColumnsInClause
 from sqlalchemy.sql import visitors, operators, ColumnElement
 from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper
 from sqlalchemy.orm import session as sessionlib
+from sqlalchemy.orm.mapper import _class_to_mapper
 from sqlalchemy.orm.util import CascadeOptions, PropertyAliasedClauses
 from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty
 from sqlalchemy.exceptions import ArgumentError
@@ -244,6 +245,14 @@ class PropertyLoader(StrategizedProperty):
         self.is_backref = is_backref
 
     class Comparator(PropComparator):
+        def __init__(self, prop, of_type=None):
+            self.prop = self.property = prop
+            if of_type:
+                self._of_type = _class_to_mapper(of_type)
+        
+        def of_type(self, cls):
+            return PropertyLoader.Comparator(self.prop, cls)
+            
         def __eq__(self, other):
             if other is None:
                 if self.prop.uselist:
@@ -267,17 +276,28 @@ class PropertyLoader(StrategizedProperty):
                 return self.prop._optimized_compare(other)
         
         def _join_and_criterion(self, criterion=None, **kwargs):
+            adapt_against = None
+
+            if getattr(self, '_of_type', None):
+                target_mapper = self._of_type
+                to_selectable = target_mapper.select_table
+                adapt_against = to_selectable
+            else:
+                target_mapper = self.prop.mapper
+                to_selectable = None
+                if target_mapper.select_table is not target_mapper.mapped_table:
+                    adapt_against = target_mapper.select_table
+                
             if self.prop._is_self_referential():
-                pac = PropertyAliasedClauses(self.prop,
-                        self.prop.primaryjoin,
-                        self.prop.secondaryjoin)
+                pj = self.prop.primary_join_against(self.prop.parent, None)
+                sj = self.prop.secondary_join_against(self.prop.parent, toselectable=to_selectable)
+
+                pac = PropertyAliasedClauses(self.prop, pj, sj)
                 j = pac.primaryjoin
                 if pac.secondaryjoin:
                     j = j & pac.secondaryjoin
             else:
-                j = self.prop.primaryjoin
-                if self.prop.secondaryjoin:
-                    j = j & self.prop.secondaryjoin
+                j = self.prop.full_join_against(self.prop.parent, None, toselectable=to_selectable)
 
             for k in kwargs:
                 crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
@@ -285,25 +305,28 @@ class PropertyLoader(StrategizedProperty):
                     criterion = crit
                 else:
                     criterion = criterion & crit
-
-            if criterion and self.prop._is_self_referential():
-                criterion = pac.adapt_clause(criterion)
             
-            return j, criterion
+            if criterion:
+                if adapt_against:
+                    criterion = ClauseAdapter(adapt_against).traverse(criterion)
+                if self.prop._is_self_referential():
+                    criterion = pac.adapt_clause(criterion)
+            
+            return j, criterion, to_selectable
             
         def any(self, criterion=None, **kwargs):
             if not self.prop.uselist:
                 raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().")
-            j, criterion = self._join_and_criterion(criterion, **kwargs)
+            j, criterion, from_obj = self._join_and_criterion(criterion, **kwargs)
 
-            return sql.exists([1], j & criterion)
+            return sql.exists([1], j & criterion, from_obj=from_obj)
 
         def has(self, criterion=None, **kwargs):
             if self.prop.uselist:
                 raise exceptions.InvalidRequestError("'has()' not implemented for collections.  Use any().")
-            j, criterion = self._join_and_criterion(criterion, **kwargs)
+            j, criterion, from_obj = self._join_and_criterion(criterion, **kwargs)
 
-            return sql.exists([1], j & criterion)
+            return sql.exists([1], j & criterion, from_obj=from_obj)
 
         def contains(self, other):
             if not self.prop.uselist:
@@ -322,9 +345,9 @@ class PropertyLoader(StrategizedProperty):
                 raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object")
             
             criterion = sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])
-            j, criterion = self._join_and_criterion(criterion)
+            j, criterion, from_obj = self._join_and_criterion(criterion)
 
-            return ~sql.exists([1], j & criterion)
+            return ~sql.exists([1], j & criterion, from_obj=from_obj)
 
     def compare(self, op, value, value_is_parent=False):
         if op == operators.eq:
@@ -700,50 +723,59 @@ class PropertyLoader(StrategizedProperty):
     def _is_self_referential(self):
         return self.parent.mapped_table is self.target or self.parent.select_table is self.target
 
-    def primary_join_against(self, mapper, selectable=None):
-        return self.__cached_join_against(mapper, selectable, True, False)
+    def primary_join_against(self, mapper, selectable=None, toselectable=None):
+        return self.__cached_join_against(mapper, selectable, toselectable, True, False)
         
-    def secondary_join_against(self, mapper):
-        return self.__cached_join_against(mapper, None, False, True)
+    def secondary_join_against(self, mapper, toselectable=None):
+        return self.__cached_join_against(mapper, None, toselectable, False, True)
         
-    def full_join_against(self, mapper, selectable=None):
-        return self.__cached_join_against(mapper, selectable, True, True)
+    def full_join_against(self, mapper, selectable=None, toselectable=None):
+        return self.__cached_join_against(mapper, selectable, toselectable, True, True)
     
-    def __cached_join_against(self, mapper, selectable, primary, secondary):
-        if selectable is None:
-            selectable = mapper.local_table
+    def __cached_join_against(self, frommapper, fromselectable, toselectable, primary, secondary):
+        if fromselectable is None:
+            fromselectable = frommapper.local_table
             
         try:
-            rec = self.__parent_join_cache[selectable]
+            rec = self.__parent_join_cache[fromselectable]
         except KeyError:
-            self.__parent_join_cache[selectable] = rec = {}
+            self.__parent_join_cache[fromselectable] = rec = {}
 
-        key = (mapper, primary, secondary)
+        key = (frommapper, primary, secondary, toselectable)
         if key in rec:
             return rec[key]
         
-        parent_equivalents = mapper._equivalent_columns
+        parent_equivalents = frommapper._equivalent_columns
         
         if primary:
-            if selectable is not mapper.local_table:
+            if toselectable:
+                primaryjoin = self.primaryjoin
+            else:
+                primaryjoin = self.polymorphic_primaryjoin
+                
+            if fromselectable is not frommapper.local_table:
                 if self.direction is sync.ONETOMANY:
-                    primaryjoin = ClauseAdapter(selectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin)
+                    primaryjoin = ClauseAdapter(fromselectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
                 elif self.direction is sync.MANYTOONE:
-                    primaryjoin = ClauseAdapter(selectable, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin)
+                    primaryjoin = ClauseAdapter(fromselectable, include=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
                 elif self.secondaryjoin:
-                    primaryjoin = ClauseAdapter(selectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin)
-            else:
-                primaryjoin = self.polymorphic_primaryjoin
+                    primaryjoin = ClauseAdapter(fromselectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
                 
             if secondary:
-                secondaryjoin = self.polymorphic_secondaryjoin
+                if toselectable:
+                    secondaryjoin = self.secondaryjoin
+                else:
+                    secondaryjoin = self.polymorphic_secondaryjoin
                 rec[key] = ret = primaryjoin & secondaryjoin
             else:
                 rec[key] = ret = primaryjoin
             return ret
         
         elif secondary:
-            rec[key] = ret = self.polymorphic_secondaryjoin
+            if toselectable:
+                rec[key] = ret = self.secondaryjoin
+            else:
+                rec[key] = ret = self.polymorphic_secondaryjoin
             return ret
 
         else:
index 9c917ec2d5b56c5989ace3450365717e7d618e4a..46f986d14e0b012d8deac9da3aa33cea89ab2d54 100644 (file)
@@ -22,21 +22,18 @@ from sqlalchemy import sql, util, exceptions, logging
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.sql import expression, visitors, operators
 from sqlalchemy.orm import mapper, object_mapper
-from sqlalchemy.orm.mapper import _state_mapper
+from sqlalchemy.orm.mapper import _state_mapper, _class_to_mapper
 from sqlalchemy.orm import util as mapperutil
 from sqlalchemy.orm import interfaces
 
 __all__ = ['Query', 'QueryContext']
 
-
+    
 class Query(object):
     """Encapsulates the object-fetching operations provided by Mappers."""
 
     def __init__(self, class_or_mapper, session=None, entity_name=None):
-        if isinstance(class_or_mapper, type):
-            self.mapper = mapper.class_mapper(class_or_mapper, entity_name=entity_name)
-        else:
-            self.mapper = class_or_mapper.compile()
+        self.mapper = _class_to_mapper(class_or_mapper, entity_name=entity_name)
         self.select_mapper = self.mapper.get_select_mapper().compile()
 
         self._session = session
@@ -422,15 +419,23 @@ class Query(object):
 
         mapper = start
         alias = self._aliases
+        
         if not isinstance(keys, list):
             keys = [keys]
         for key in keys:
             use_selectable = None
+            of_type = None
+
             if isinstance(key, tuple):
                 key, use_selectable = key
 
             if isinstance(key, interfaces.PropComparator):
                 prop = key.property
+                if getattr(key, '_of_type', None):
+                    if use_selectable:
+                        raise exceptions.InvalidRequestError("Can't specify use_selectable along with polymorphic property created via of_type().")
+                    of_type = key._of_type
+                    use_selectable = key._of_type.select_table
             else:
                 prop = mapper.get_property(key, resolve_synonyms=True)
 
@@ -445,45 +450,29 @@ class Query(object):
 
             if prop.select_table not in currenttables or create_aliases or use_selectable:
                 if prop.secondary:
-                    if use_selectable:
+                    if use_selectable or create_aliases:
                         alias = mapperutil.PropertyAliasedClauses(prop,
                             prop.primary_join_against(mapper, adapt_against),
-                            prop.secondary_join_against(mapper),
+                            prop.secondary_join_against(mapper, toselectable=use_selectable),
                             alias,
                             alias=use_selectable
                         )
                         crit = alias.primaryjoin
                         clause = clause.join(alias.secondary, crit, isouter=outerjoin).join(alias.alias, alias.secondaryjoin, isouter=outerjoin)
-                    elif create_aliases:
-                        alias = mapperutil.PropertyAliasedClauses(prop,
-                            prop.primary_join_against(mapper, adapt_against),
-                            prop.secondary_join_against(mapper),
-                            alias
-                        )
-                        crit = alias.primaryjoin
-                        clause = clause.join(alias.secondary, crit, isouter=outerjoin).join(alias.alias, alias.secondaryjoin, isouter=outerjoin)
                     else:
                         crit = prop.primary_join_against(mapper, adapt_against)
                         clause = clause.join(prop.secondary, crit, isouter=outerjoin)
                         clause = clause.join(prop.select_table, prop.secondary_join_against(mapper), isouter=outerjoin)
                 else:
-                    if use_selectable:
+                    if use_selectable or create_aliases:
                         alias = mapperutil.PropertyAliasedClauses(prop,
-                            prop.primary_join_against(mapper, adapt_against),
+                            prop.primary_join_against(mapper, adapt_against, toselectable=use_selectable),
                             None,
                             alias,
                             alias=use_selectable
                         )
                         crit = alias.primaryjoin
                         clause = clause.join(alias.alias, crit, isouter=outerjoin)
-                    elif create_aliases:
-                        alias = mapperutil.PropertyAliasedClauses(prop,
-                            prop.primary_join_against(mapper, adapt_against),
-                            None,
-                            alias
-                        )
-                        crit = alias.primaryjoin
-                        clause = clause.join(alias.alias, crit, isouter=outerjoin)
                     else:
                         crit = prop.primary_join_against(mapper, adapt_against)
                         clause = clause.join(prop.select_table, crit, isouter=outerjoin)
@@ -492,7 +481,7 @@ class Query(object):
                 # does not use secondary tables
                 raise exceptions.InvalidRequestError("Can't join to property '%s'; a path to this table along a different secondary table already exists.  Use the `alias=True` argument to `join()`." % prop.key)
 
-            mapper = prop.mapper
+            mapper = of_type or prop.mapper
 
             if use_selectable:
                 adapt_against = use_selectable
@@ -707,7 +696,7 @@ class Query(object):
             q._adapter = sql_util.ClauseAdapter(q._from_obj, equivalents=q.mapper._equivalent_columns)
         return q
 
-
+    
     def select_from(self, from_obj):
         """Set the `from_obj` parameter of the query and return the newly
         resulting ``Query``.  This replaces the table which this Query selects
index 503364787264560ab4a26b15d542c2ec15ce144c..3571480292736b218a6340d5f743ca22e529c0d9 100644 (file)
@@ -200,7 +200,30 @@ def make_test(select_type):
             self.assertEquals(sess.query(Company).join('employees').filter(Person.name=='vlad').one(), c2)
 
             self.assertEquals(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2)
-    
+        
+        def test_polymorphic_any(self):
+            sess = create_session()
+
+            self.assertEquals(
+                sess.query(Company).filter(Company.employees.of_type(Engineer).any(Engineer.primary_language=='cobol')).one(),
+                c2
+                )
+
+            self.assertEquals(
+                sess.query(Company).filter(Company.employees.of_type(Boss).any(Boss.golf_swing=='fore')).one(),
+                c1
+                )
+            self.assertEquals(
+                sess.query(Company).filter(Company.employees.of_type(Boss).any(Manager.manager_name=='pointy')).one(),
+                c1
+                )
+
+            if select_type == '':
+                self.assertEquals(
+                    sess.query(Company).filter(Company.employees.any(and_(Engineer.primary_language=='cobol', people.c.person_id==engineers.c.person_id))).one(),
+                    c2
+                    )
+                
         def test_join_to_subclass(self):
             sess = create_session()
 
@@ -218,6 +241,16 @@ def make_test(select_type):
                 self.assertEquals(sess.query(Person).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
                 self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).all(), [c1, c2])
                 self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1])
+            
+            # non-polymorphic
+            self.assertEquals(sess.query(Engineer).join(Engineer.machines).all(), [e1, e2, e3])
+            self.assertEquals(sess.query(Engineer).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
+
+            # here's the new way
+            self.assertEquals(sess.query(Company).join(Company.employees.of_type(Engineer)).filter(Engineer.primary_language=='java').all(), [c1])
+            self.assertEquals(sess.query(Company).join([Company.employees.of_type(Engineer), 'machines']).filter(Machine.name.ilike("%thinkpad%")).all(), [c1])
+
+
         
         def test_join_through_polymorphic(self):