]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
filter_by([joinpath], ...) is gone. join([path], aliased=True) replaces it, all...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 22 Jul 2007 03:26:13 +0000 (03:26 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 22 Jul 2007 03:26:13 +0000 (03:26 +0000)
against that alias; represents a much more flexible and consistent solution.
needs some tweaks and can then work with self-referential loading too.

lib/sqlalchemy/orm/query.py
test/orm/eager_relations.py
test/orm/query.py

index 63620dbdd6fa59edd5e1c81ead8b6645c15e1a0f..004024e58d3831baa1572179e05aa12402c65cc2 100644 (file)
@@ -37,6 +37,7 @@ class Query(object):
         self._criterion = None
         self._column_aggregate = None
         self._joinpoint = self.mapper
+        self._aliases = None
         self._from_obj = [self.table]
         self._populate_existing = False
         self._version_check = False
@@ -253,6 +254,12 @@ class Query(object):
             
         if criterion is not None and not isinstance(criterion, sql.ClauseElement):
             raise exceptions.ArgumentError("filter() argument must be of type sqlalchemy.sql.ClauseElement or string")
+        
+        if self._aliases is not None:
+            adapter = sql_util.ClauseAdapter(self._aliases[0])
+            for alias in self._aliases[1:]:
+                adapter.chain(sql_util.ClauseAdapter(alias))
+            criterion = adapter.traverse(criterion, clone=True)
             
         q = self._clone()
         if q._criterion is not None:
@@ -261,22 +268,15 @@ class Query(object):
             q._criterion = criterion
         return q
 
-    def filter_by(self, *args, **kwargs):
+    def filter_by(self, **kwargs):
         """apply the given filtering criterion to the query and return the newly resulting ``Query``."""
 
         #import properties
         
-        if len(args) > 1:
-            raise exceptions.ArgumentError("filter_by() takes either zero positional arguments, or one scalar or list argument indicating a property search path.")
-        if len(args) == 1:
-            path = args[0]
-            (join, joinpoint, alias) = self._join_to(path, outerjoin=False, start=self.mapper, create_aliases=True)
-            clause = None
-        else:
-            alias = None
-            join = None
-            clause = None
-            joinpoint = self._joinpoint
+        alias = None
+        join = None
+        clause = None
+        joinpoint = self._joinpoint
 
         for key, value in kwargs.iteritems():
             prop = joinpoint.get_property(key, resolve_synonyms=True)
@@ -309,12 +309,14 @@ class Query(object):
             
         mapper = start
         alias = None
+        aliases = []
         for key in util.to_list(keys):
             prop = mapper.get_property(key, resolve_synonyms=True)
-            if prop._is_self_referential():
-                raise exceptions.InvalidRequestError("Self-referential query on '%s' property must be constructed manually using an Alias object for the related table." % str(prop))
+            if prop._is_self_referential() and not create_aliases:
+                # TODO: create_aliases automatically ? probably
+                raise exceptions.InvalidRequestError("Self-referential query on '%s' property requries create_aliases=True argument." % str(prop))
             # dont re-join to a table already in our from objects
-            if prop.select_table not in currenttables:
+            if prop.select_table not in currenttables or create_aliases:
                 if outerjoin:
                     if prop.secondary:
                         clause = clause.outerjoin(prop.secondary, prop.get_join(mapper, primary=True, secondary=False))
@@ -326,11 +328,13 @@ class Query(object):
                         if create_aliases:
                             join = prop.get_join(mapper, primary=True, secondary=False)
                             secondary_alias = prop.secondary.alias()
+                            aliases.append(secondary_alias)
                             if alias is not None:
                                 join = sql_util.ClauseAdapter(alias).traverse(join, clone=True)
                             sql_util.ClauseAdapter(secondary_alias).traverse(join)
                             clause = clause.join(secondary_alias, join)
                             alias = prop.select_table.alias()
+                            aliases.append(alias)
                             join = prop.get_join(mapper, primary=False)
                             join = sql_util.ClauseAdapter(secondary_alias).traverse(join, clone=True)
                             sql_util.ClauseAdapter(alias).traverse(join)
@@ -344,15 +348,16 @@ class Query(object):
                             if alias is not None:
                                 join = sql_util.ClauseAdapter(alias).traverse(join, clone=True)
                             alias = prop.select_table.alias()
+                            aliases.append(alias)
                             join = sql_util.ClauseAdapter(alias).traverse(join, clone=True)
                             clause = clause.join(alias, join)
                         else:
                             clause = clause.join(prop.select_table, prop.get_join(mapper))
             mapper = prop.mapper
         if create_aliases:
-            return (clause, mapper, alias)
+            return (clause, mapper, aliases)
         else:
-            return (clause, mapper)
+            return (clause, mapper, None)
 
     def _generative_col_aggregate(self, col, func):
         """apply the given aggregate function to the query and return the newly
@@ -442,7 +447,7 @@ class Query(object):
             q._group_by = q._group_by + util.to_list(criterion)
         return q
 
-    def join(self, prop):
+    def join(self, prop, aliased=False):
         """create a join of this ``Query`` object's criterion
         to a relationship and return the newly resulting ``Query``.
 
@@ -451,12 +456,13 @@ class Query(object):
         """
         
         q = self._clone()
-        (clause, mapper) = self._join_to(prop, outerjoin=False, start=self.mapper)
+        (clause, mapper, aliases) = self._join_to(prop, outerjoin=False, start=self.mapper, create_aliases=aliased)
         q._from_obj = [clause]
         q._joinpoint = mapper
+        q._aliases = aliases
         return q
 
-    def outerjoin(self, prop):
+    def outerjoin(self, prop, aliased=False):
         """create a left outer join of this ``Query`` object's criterion
         to a relationship and return the newly resulting ``Query``.
         
@@ -464,9 +470,10 @@ class Query(object):
         property names.
         """
         q = self._clone()
-        (clause, mapper) = self._join_to(prop, outerjoin=True, start=self.mapper)
+        (clause, mapper, aliases) = self._join_to(prop, outerjoin=True, start=self.mapper, create_aliases=aliased)
         q._from_obj = [clause]
         q._joinpoint = mapper
+        q._aliases = aliases
         return q
 
     def reset_joinpoint(self):
@@ -480,6 +487,7 @@ class Query(object):
 
         q = self._clone()
         q._joinpoint = q.mapper
+        q._aliases = None
         return q
 
 
index c2d1b28bc97f2635ef7c9170adadb132028e162f..4dcdeac37d628cc76dd61d8d647ac4455b4d147a 100644 (file)
@@ -477,5 +477,6 @@ class SelfReferentialEagerTest(testbase.ORMTest):
                 Node(data='n13')
             ]) == d
         self.assert_sql_count(testbase.db, go, 1)
+
 if __name__ == '__main__':
     testbase.main()
index 084f86c348864ac100921f8d8e20fada202a6376..6fa5c9644ca96939d472f8e4d704b15fd2411341 100644 (file)
@@ -403,6 +403,32 @@ class JoinTest(QueryTest):
         result = create_session().query(User).select_from(users.join(oalias)).filter(oalias.c.description.in_("order 1", "order 2", "order 3")).join(['orders', 'items']).filter_by(id=4).all()
         assert [User(id=7, name='jack')] == result
 
+    def test_aliased(self):
+        """test automatic generation of aliased joins."""
+
+        sess = create_session()
+
+        # test a basic aliasized path
+        q = sess.query(User).join('addresses', aliased=True).filter_by(email_address='jack@bean.com')
+        assert [User(id=7)] == q.all()
+
+        q = sess.query(User).join('addresses', aliased=True).filter(Address.email_address=='jack@bean.com')
+        assert [User(id=7)] == q.all()
+
+        # test two aliasized paths, one to 'orders' and the other to 'orders','items'.
+        # one row is returned because user 7 has order 3 and also has order 1 which has item 1
+        # this tests a o2m join and a m2m join.
+        q = sess.query(User).join('orders', aliased=True).filter(Order.description=="order 3").join(['orders', 'items'], aliased=True).filter(Item.description=="item 1")
+        assert q.count() == 1
+        assert [User(id=7)] == q.all()
+
+        # test the control version - same joins but not aliased.  rows are not returned because order 3 does not have item 1
+        # addtionally by placing this test after the previous one, test that the "aliasing" step does not corrupt the
+        # join clauses that are cached by the relationship.
+        q = sess.query(User).join('orders').filter(Order.description=="order 3").join(['orders', 'items']).filter(Order.description=="item 1")
+        assert [] == q.all()
+        assert q.count() == 0
+
 
 class SynonymTest(QueryTest):
     keep_mappers = True
@@ -586,32 +612,40 @@ class InstancesTest(QueryTest):
             
         assert q.all() == expected
 
-class FilterByTest(QueryTest):
-    def test_aliased(self):
-        """test automatic generation of aliased joins using filter_by()."""
-        
+# this test not working yet
+class SelfReferentialTest(object): #testbase.ORMTest):
+    def define_tables(self, metadata):
+        global nodes
+        nodes = Table('nodes', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('parent_id', Integer, ForeignKey('nodes.id')),
+            Column('data', String(30)))
+
+    def test_join(self):
+        class Node(Base):
+            def append(self, node):
+                self.children.append(node)
+
+        mapper(Node, nodes, properties={
+            'children':relation(Node, lazy=True, join_depth=3)
+        })
         sess = create_session()
-        
-        # test a basic aliasized path
-        q = sess.query(User).filter_by(['addresses'], email_address='jack@bean.com')
-        assert [User(id=7)] == q.all()
-
-        # test two aliasized paths, one to 'orders' and the other to 'orders','items'.
-        # one row is returned because user 7 has order 3 and also has order 1 which has item 1
-        # this tests a o2m join and a m2m join.
-        q = sess.query(User).filter_by(['orders'], description="order 3").filter_by(['orders', 'items'], description="item 1")
-        assert q.count() == 1
-        assert [User(id=7)] == q.all()
-        
-        # test the control version - same joins but not aliased.  rows are not returned because order 3 does not have item 1
-        # addtionally by placing this test after the previous one, test that the "aliasing" step does not corrupt the
-        # join clauses that are cached by the relationship.
-        q = sess.query(User).join('orders').filter_by(description="order 3").join(['orders', 'items']).filter_by(description="item 1")
-        assert [] == q.all()
-        assert q.count() == 0
-        
-        
-        
+        n1 = Node(data='n1')
+        n1.append(Node(data='n11'))
+        n1.append(Node(data='n12'))
+        n1.append(Node(data='n13'))
+        n1.children[1].append(Node(data='n121'))
+        n1.children[1].append(Node(data='n122'))
+        n1.children[1].append(Node(data='n123'))
+        sess.save(n1)
+        sess.flush()
+        sess.clear()
+        
+        # TODO: the aliasing of the join in query._join_to has to limit the aliasing
+        # among local_side / remote_side (add local_side as an attribute on PropertyLoader)
+        # also implement this idea in EagerLoader
+        node = sess.query(Node).join('children', aliased=True).filter_by(data='n122').first()
+        assert node.data=='n12'
 
 if __name__ == '__main__':
     testbase.main()