]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- specifying joins in the from_obj argument of query.select() will
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 27 Sep 2006 07:08:26 +0000 (07:08 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 27 Sep 2006 07:08:26 +0000 (07:08 +0000)
replace the main table of the query, if the table is somewhere within
the given from_obj.  this makes it possible to produce custom joins and
outerjoins in queries without the main table getting added twice.
[ticket:315]
- added join_to and outerjoin_to transformative methods to SelectResults,
to build up join/outerjoin conditions based on property names. also
added select_from to explicitly set from_obj parameter.
- factored "results" arrays from the mapper test suite and into the
"tables" mapper
- added "viewonly" param to docs

CHANGES
doc/build/content/adv_datamapping.txt
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/ext/selectresults.py
lib/sqlalchemy/orm/query.py
test/ext/selectresults.py
test/orm/mapper.py
test/tables.py
test/testbase.py

diff --git a/CHANGES b/CHANGES
index 3139c57f09ce57a3499cf5b5b9b1765258e9d9c3..8ed6d97720a563e87bc1e239321d3b9ff998356d 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -29,6 +29,8 @@ automatic "row switch" feature
 - changed "for_update" parameter to accept False/True/"nowait"
 and "read", the latter two of which are interpreted only by
 Oracle and Mysql [ticket:292]
+- added "viewonly" flag to relation(), allows construction of
+relations that have no effect on the flush() process.
 - added "lockmode" argument to base Query select/get functions, 
 including "with_lockmode" function to get a Query copy that has 
 a default locking mode.  Will translate "read"/"update" 
@@ -73,6 +75,14 @@ of all persistent objects where the attribute manager detects changes.
 The basic issue thats fixed is detecting changes on PickleType 
 objects, but also generalizes type handling and "modified" object
 checking to be more complete and extensible.
+- specifying joins in the from_obj argument of query.select() will
+replace the main table of the query, if the table is somewhere within
+the given from_obj.  this makes it possible to produce custom joins and
+outerjoins in queries without the main table getting added twice.
+[ticket:315]
+- added join_to and outerjoin_to transformative methods to SelectResults,
+to build up join/outerjoin conditions based on property names. also
+added select_from to explicitly set from_obj parameter.
 
 0.2.8
 - cleanup on connection methods + documentation.  custom DBAPI
index f1afd5c6ee45861fb395df56e16b61d9c9ce59df..8e6cdc5c63cb05e83d8f671d3297a9b59380f5b7 100644 (file)
@@ -255,6 +255,7 @@ Keyword options to the `relation` function include:
 * order_by - indicates the ordering that should be applied when loading these items.  See the section [advdatamapping_orderby](rel:advdatamapping_orderby) for details.
 * association - When specifying a many to many relationship with an association object, this keyword should reference the mapper or class of the target object of the association.  See the example in [datamapping_association](rel:datamapping_association).
 * post_update - this indicates that the relationship should be handled by a second UPDATE statement after an INSERT, or before a DELETE.  using this flag essentially means the relationship will not incur any "dependency" between parent and child item, as the particular foreign key relationship between them is handled by a second statement.  use this flag when a particular mapping arrangement will incur two rows that are dependent on each other, such as a table that has a one-to-many relationship to a set of child rows, and also has a column that references a single child row within that list (i.e. both tables contain a foreign key to each other).  If a flush() operation returns an error that a "cyclical dependency" was detected, this is a cue that you might want to use post_update.
+* viewonly=(True|False) - when set to True, the relation is used only for loading objects within the relationship, and has no effect on the unit-of-work flush process.  relations with viewonly can specify any kind of join conditions to provide additional views of related objects onto a parent object.
 
 ### Controlling Ordering {@name=orderby}
 
index 5db157cbb191d7f7b2b38bb997ea121c9e396b2c..e6d331109c27ff8072d51a4e52c5674d923b93d8 100644 (file)
@@ -298,6 +298,7 @@ class OracleCompiler(ansisql.ANSICompiler):
         
         self.froms[join] = self.get_from_text(join.left) + ", " + self.get_from_text(join.right)
         self.wheres[join] = sql.and_(self.wheres.get(join.left, None), join.onclause)
+        self.strings[join] = self.froms[join]
 
         if join.isouter:
             # if outer join, push on the right side table as the current "outertable"
index a35cdfa7eea205433c0528d8b3d8195e2bf5995f..93698a5e01927fd25b528235fe78a084dbe57567 100644 (file)
@@ -18,13 +18,14 @@ class SelectResults(object):
     instance with further limiting criterion added. When interpreted
     in an iterator context (such as via calling list(selectresults)), executes the query."""
     
-    def __init__(self, query, clause=None, ops={}):
+    def __init__(self, query, clause=None, ops={}, joinpoint=None):
         """constructs a new SelectResults using the given Query object and optional WHERE 
         clause.  ops is an optional dictionary of bind parameter values."""
         self._query = query
         self._clause = clause
         self._ops = {}
         self._ops.update(ops)
+        self._joinpoint = joinpoint or (self._query.table, self._query.mapper)
 
     def count(self):
         """executes the SQL count() function against the SelectResults criterion."""
@@ -60,7 +61,7 @@ class SelectResults(object):
 
     def clone(self):
         """creates a copy of this SelectResults."""
-        return SelectResults(self._query, self._clause, self._ops.copy())
+        return SelectResults(self._query, self._clause, self._ops.copy(), self._joinpoint)
         
     def filter(self, clause):
         """applies an additional WHERE clause against the query."""
@@ -68,23 +69,76 @@ class SelectResults(object):
         new._clause = sql.and_(self._clause, clause)
         return new
 
+    def select(self, clause):
+        return self.filter(clause)
+        
     def order_by(self, order_by):
-        """applies an ORDER BY to the query."""
+        """apply an ORDER BY to the query."""
         new = self.clone()
         new._ops['order_by'] = order_by
         return new
 
     def limit(self, limit):
-        """applies a LIMIT to the query."""
+        """apply a LIMIT to the query."""
         return self[:limit]
 
     def offset(self, offset):
-        """applies an OFFSET to the query."""
+        """apply an OFFSET to the query."""
         return self[offset:]
 
     def list(self):
-        """returns the results represented by this SelectResults as a list.  this results in an execution of the underlying query."""
+        """return the results represented by this SelectResults as a list.  
+        
+        this results in an execution of the underlying query."""
         return list(self)
+    
+    def select_from(self, from_obj):
+        """set the from_obj parameter of the query to a specific table or set of tables.
+        
+        from_obj is a list."""
+        new = self.clone()
+        new._ops['from_obj'] = from_obj
+        return new
+        
+    def join_to(self, prop):
+        """join the table of this SelectResults to the table located against the given property name.
+        
+        subsequent calls to join_to or outerjoin_to will join against the rightmost table located from the 
+        previous join_to or outerjoin_to call, searching for the property starting with the rightmost mapper
+        last located."""
+        new = self.clone()
+        (clause, mapper) = self._join_to(prop, outerjoin=False)
+        new._ops['from_obj'] = [clause]
+        new._joinpoint = (clause, mapper)
+        return new
+        
+    def outerjoin_to(self, prop):
+        """outer join the table of this SelectResults to the table located against the given property name.
+        
+        subsequent calls to join_to or outerjoin_to will join against the rightmost table located from the 
+        previous join_to or outerjoin_to call, searching for the property starting with the rightmost mapper
+        last located."""
+        new = self.clone()
+        (clause, mapper) = self._join_to(prop, outerjoin=True)
+        new._ops['from_obj'] = [clause]
+        new._joinpoint = (clause, mapper)
+        return new
+    
+    def _join_to(self, prop, outerjoin=False):
+        [keys,p] = self._query._locate_prop(prop, start=self._joinpoint[1])
+        clause = self._joinpoint[0]
+        mapper = self._joinpoint[1]
+        for key in keys:
+            prop = mapper.props[key]
+            if outerjoin:
+                clause = clause.outerjoin(prop.mapper.mapped_table, prop.get_join())
+            else:
+                clause = clause.join(prop.mapper.mapped_table, prop.get_join())
+            mapper = prop.mapper
+        return (clause, mapper)
+        
+    def compile(self):
+        return self._query.compile(self._clause, **self._ops)
         
     def __getitem__(self, item):
         if isinstance(item, slice):
index 0a6a050055826bce72f4bf87c8e33d61505345b9..db497bc37f093aa0fc03c491472a7719d5e0c596 100644 (file)
@@ -1,4 +1,4 @@
- # orm/query.py
+# orm/query.py
 # Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com
 #
 # This module is part of SQLAlchemy and is released under
@@ -112,7 +112,7 @@ class Query(object):
                 clause &= c
         return clause
 
-    def _locate_prop(self, key):
+    def _locate_prop(self, key, start=None):
         import properties
         keys = []
         seen = util.Set()
@@ -137,7 +137,7 @@ class Query(object):
                         return x
                 else:
                     return None
-        p = search_for_prop(self.mapper)
+        p = search_for_prop(start or self.mapper)
         if p is None:
             raise exceptions.InvalidRequestError("Cant locate property named '%s'" % key)
         return [keys, p]
@@ -163,10 +163,9 @@ class Query(object):
             else:
                 clause &= prop.get_join()
             mapper = prop.mapper
-            
+
         return clause
-    
-        
+
     def selectfirst_by(self, *args, **params):
         """works like select_by(), but only returns the first result by itself, or None if no 
         objects returned.  Synonymous with get_by()"""
@@ -340,10 +339,15 @@ class Query(object):
         
         if self.mapper.single and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None:
             whereclause = sql.and_(whereclause, self.mapper.polymorphic_on==self.mapper.polymorphic_identity)
+        
+        alltables = []
+        for l in [sql_util.TableFinder(x) for x in from_obj]:
+            alltables += l
             
-        if self._should_nest(**kwargs):
+        if self.table not in alltables:
             from_obj.append(self.table)
             
+        if self._should_nest(**kwargs):
             # if theres an order by, add those columns to the column list
             # of the "rowcount" query we're going to make
             if order_by:
@@ -375,7 +379,6 @@ class Query(object):
                 [o.accept_visitor(aliasizer) for  o in order_by]
                 statement.order_by(*util.to_list(order_by))
         else:
-            from_obj.append(self.table)
             statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **kwargs)
             if order_by:
                 statement.order_by(*util.to_list(order_by))
index 6997dfe6bb520665d464305d34a28f06225614da..15d88e7b6756abf4a750b33ff12fa8e09bd86174 100644 (file)
@@ -1,9 +1,10 @@
-from testbase import PersistTest
+from testbase import PersistTest, AssertMixin
 import testbase
+import tables
 
 from sqlalchemy import *
 
-from sqlalchemy.ext.selectresults import SelectResultsExt
+from sqlalchemy.ext.selectresults import SelectResultsExt, SelectResults
 
 class Foo(object):
     pass
@@ -122,7 +123,53 @@ class SelectResultsTest2(PersistTest):
         res = self.query.select(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1), distinct=True)
         self.assertEqual(res.count(), 1)
 
-class SelectResultsTest3(PersistTest):
+class RelationsTest(AssertMixin):
+    def setUpAll(self):
+        tables.create()
+        tables.data()
+    def tearDownAll(self):
+        tables.drop()
+    def tearDown(self):
+        clear_mappers()
+    def test_jointo(self):
+        """test the join_to and outerjoin_to functions on SelectResults"""
+        mapper(tables.User, tables.users, properties={
+            'orders':relation(mapper(tables.Order, tables.orders, properties={
+                'items':relation(mapper(tables.Item, tables.orderitems))
+            }))
+        })
+        session = create_session()
+        query = SelectResults(session.query(tables.User))
+        x = query.join_to('orders').join_to('items').select(tables.Item.c.item_id==2)
+        print x.compile()
+        self.assert_result(list(x), tables.User, tables.user_result[2])
+    def test_outerjointo(self):
+        """test the join_to and outerjoin_to functions on SelectResults"""
+        mapper(tables.User, tables.users, properties={
+            'orders':relation(mapper(tables.Order, tables.orders, properties={
+                'items':relation(mapper(tables.Item, tables.orderitems))
+            }))
+        })
+        session = create_session()
+        query = SelectResults(session.query(tables.User))
+        x = query.outerjoin_to('orders').outerjoin_to('items').select(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
+        print x.compile()
+        self.assert_result(list(x), tables.User, *tables.user_result[1:3])
+    def test_from(self):
+        mapper(tables.User, tables.users, properties={
+            'orders':relation(mapper(tables.Order, tables.orders, properties={
+                'items':relation(mapper(tables.Item, tables.orderitems))
+            }))
+        })
+        session = create_session()
+        query = SelectResults(session.query(tables.User))
+        x = query.select_from([tables.users.outerjoin(tables.orders).outerjoin(tables.orderitems)]).\
+            filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
+        print x.compile()
+        self.assert_result(list(x), tables.User, *tables.user_result[1:3])
+        
+
+class CaseSensitiveTest(PersistTest):
     def setUpAll(self):
         self.install_threadlocal()
         global metadata, table1, table2
index 786f410e59147975b7ac4f7929b499322e28a939..12059d28e228f23ed563320829a2eafa69fd3fbf 100644 (file)
@@ -9,56 +9,6 @@ import tables
 
 """tests general mapper operations with an emphasis on selecting/loading"""
 
-user_result = [{'user_id' : 7}, {'user_id' : 8}, {'user_id' : 9}]
-user_address_result = [
-{'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])},
-{'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}])},
-{'user_id' : 9, 'addresses' : (Address, [])}
-]
-user_address_orders_result = [{'user_id' : 7, 
-    'addresses' : (Address, [{'address_id' : 1}]),
-    'orders' : (Order, [{'order_id' : 1}, {'order_id' : 3},{'order_id' : 5},])
-},
-
-{'user_id' : 8, 
-    'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}]),
-    'orders' : (Order, [])
-},
-{'user_id' : 9, 
-    'addresses' : (Address, []),
-    'orders' : (Order, [{'order_id' : 2},{'order_id' : 4}])
-}]
-
-user_all_result = [
-{'user_id' : 7, 
-    'addresses' : (Address, [{'address_id' : 1}]),
-    'orders' : (Order, [
-        {'order_id' : 1, 'items': (Item, [])}, 
-        {'order_id' : 3, 'items': (Item, [{'item_id':3, 'item_name':'item 3'}, {'item_id':4, 'item_name':'item 4'}, {'item_id':5, 'item_name':'item 5'}])},
-        {'order_id' : 5, 'items': (Item, [])},
-        ])
-},
-{'user_id' : 8, 
-    'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}]),
-    'orders' : (Order, [])
-},
-{'user_id' : 9, 
-    'addresses' : (Address, []),
-    'orders' : (Order, [
-        {'order_id' : 2, 'items': (Item, [{'item_id':1, 'item_name':'item 1'}, {'item_id':2, 'item_name':'item 2'}])},
-        {'order_id' : 4, 'items': (Item, [])}
-    ])
-}]
-
-item_keyword_result = [
-{'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
-{'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2, 'name':'red'}, {'keyword_id' : 5, 'name':'small'}, {'keyword_id' : 7, 'name':'square'}])},
-{'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 3,'name':'green'}, {'keyword_id' : 4,'name':'big'}, {'keyword_id' : 6,'name':'round'}])},
-{'item_id' : 4, 'keywords' : (Keyword, [])},
-{'item_id' : 5, 'keywords' : (Keyword, [])}
-]
-
-
 class MapperSuperTest(AssertMixin):
     def setUpAll(self):
         tables.create()
@@ -317,7 +267,20 @@ class MapperTest(MapperSuperTest):
 
         l = q.select((orderitems.c.item_name=='item 4') & q.join_to('items'))
         self.assert_result(l, User, user_result[0])
-        
+    
+    def testcustomjoin(self):
+        """test that the from_obj parameter to query.select() can be used
+        to totally replace the FROM parameters of the generated query."""
+        m = mapper(User, users, properties={
+            'orders':relation(mapper(Order, orders, properties={
+                'items':relation(mapper(Item, orderitems))
+            }))
+        })
+
+        q = create_session().query(m)
+        l = q.select((orderitems.c.item_name=='item 4'), from_obj=[users.join(orders).join(orderitems)])
+        self.assert_result(l, User, user_result[0])
+            
     def testorderby(self):
         """test ordering at the mapper and query level"""
         # TODO: make a unit test out of these various combinations
index 72870ec1eb7413a18846576c6ac844ffc65e82e0..f3b78a125c1eeeea056b17ac4b58a8c120c0898b 100644 (file)
@@ -3,9 +3,6 @@ from sqlalchemy import *
 import os
 import testbase
 
-__all__ = ['db', 'users', 'addresses', 'orders', 'orderitems', 'keywords', 'itemkeywords', 'userkeywords',
-            'User', 'Address', 'Order', 'Item', 'Keyword'
-        ]
 
 ECHO = testbase.echo
 db = testbase.db
@@ -129,6 +126,7 @@ def data():
         dict(keyword_id=7, item_id=2),
         dict(keyword_id=6, item_id=3)
     )
+
     
 class User(object):
     def __init__(self):
@@ -168,5 +166,54 @@ class Keyword(object):
         return "Keyword: %s/%s" % (repr(getattr(self, 'keyword_id', None)),repr(self.name))
 
 
+user_result = [{'user_id' : 7}, {'user_id' : 8}, {'user_id' : 9}]
+
+user_address_result = [
+{'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])},
+{'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}])},
+{'user_id' : 9, 'addresses' : (Address, [])}
+]
+
+user_address_orders_result = [{'user_id' : 7, 
+    'addresses' : (Address, [{'address_id' : 1}]),
+    'orders' : (Order, [{'order_id' : 1}, {'order_id' : 3},{'order_id' : 5},])
+    },
+    {'user_id' : 8, 
+        'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}]),
+        'orders' : (Order, [])
+    },
+    {'user_id' : 9, 
+        'addresses' : (Address, []),
+        'orders' : (Order, [{'order_id' : 2},{'order_id' : 4}])
+}]
+
+user_all_result = [
+{'user_id' : 7, 
+    'addresses' : (Address, [{'address_id' : 1}]),
+    'orders' : (Order, [
+        {'order_id' : 1, 'items': (Item, [])}, 
+        {'order_id' : 3, 'items': (Item, [{'item_id':3, 'item_name':'item 3'}, {'item_id':4, 'item_name':'item 4'}, {'item_id':5, 'item_name':'item 5'}])},
+        {'order_id' : 5, 'items': (Item, [])},
+        ])
+},
+{'user_id' : 8, 
+    'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}]),
+    'orders' : (Order, [])
+},
+{'user_id' : 9, 
+    'addresses' : (Address, []),
+    'orders' : (Order, [
+        {'order_id' : 2, 'items': (Item, [{'item_id':1, 'item_name':'item 1'}, {'item_id':2, 'item_name':'item 2'}])},
+        {'order_id' : 4, 'items': (Item, [])}
+    ])
+}]
+
+item_keyword_result = [
+{'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
+{'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2, 'name':'red'}, {'keyword_id' : 5, 'name':'small'}, {'keyword_id' : 7, 'name':'square'}])},
+{'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 3,'name':'green'}, {'keyword_id' : 4,'name':'big'}, {'keyword_id' : 6,'name':'round'}])},
+{'item_id' : 4, 'keywords' : (Keyword, [])},
+{'item_id' : 5, 'keywords' : (Keyword, [])}
+]
 
 #db.echo = True
index 702e00e83e9d92f10744878aa424ab465183f12a..d115d400ad00bedac8a8be183dbec716764c6ae0 100644 (file)
@@ -25,7 +25,6 @@ class Logger(object):
             local_stdout.write(msg)
     def flush(self):
         pass
-sys.stdout = Logger()    
 
 def echo_text(text):
     print text
@@ -363,10 +362,12 @@ parse_argv()
 
                     
 def runTests(suite):
+    sys.stdout = Logger()    
     runner = unittest.TextTestRunner(verbosity = quiet and 1 or 2)
     runner.run(suite)
     
 def main():
+    
     if len(sys.argv[1:]):
         suite =unittest.TestLoader().loadTestsFromNames(sys.argv[1:], __import__('__main__'))
     else: