From: Mike Bayer Date: Sun, 22 Jul 2007 03:26:13 +0000 (+0000) Subject: filter_by([joinpath], ...) is gone. join([path], aliased=True) replaces it, all... X-Git-Tag: rel_0_4_6~55 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=da72aac46508867e3b90120c6a07c04bff926e37;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git filter_by([joinpath], ...) is gone. join([path], aliased=True) replaces it, all subsequent filter() criterion is converted against that alias; represents a much more flexible and consistent solution. needs some tweaks and can then work with self-referential loading too. --- diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 63620dbdd6..004024e58d 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -37,6 +37,7 @@ class Query(object): self._criterion = None self._column_aggregate = None self._joinpoint = self.mapper + self._aliases = None self._from_obj = [self.table] self._populate_existing = False self._version_check = False @@ -253,6 +254,12 @@ class Query(object): if criterion is not None and not isinstance(criterion, sql.ClauseElement): raise exceptions.ArgumentError("filter() argument must be of type sqlalchemy.sql.ClauseElement or string") + + if self._aliases is not None: + adapter = sql_util.ClauseAdapter(self._aliases[0]) + for alias in self._aliases[1:]: + adapter.chain(sql_util.ClauseAdapter(alias)) + criterion = adapter.traverse(criterion, clone=True) q = self._clone() if q._criterion is not None: @@ -261,22 +268,15 @@ class Query(object): q._criterion = criterion return q - def filter_by(self, *args, **kwargs): + def filter_by(self, **kwargs): """apply the given filtering criterion to the query and return the newly resulting ``Query``.""" #import properties - if len(args) > 1: - raise exceptions.ArgumentError("filter_by() takes either zero positional arguments, or one scalar or list argument indicating a property search path.") - if len(args) == 1: - path = args[0] - (join, joinpoint, alias) = self._join_to(path, outerjoin=False, start=self.mapper, create_aliases=True) - clause = None - else: - alias = None - join = None - clause = None - joinpoint = self._joinpoint + alias = None + join = None + clause = None + joinpoint = self._joinpoint for key, value in kwargs.iteritems(): prop = joinpoint.get_property(key, resolve_synonyms=True) @@ -309,12 +309,14 @@ class Query(object): mapper = start alias = None + aliases = [] for key in util.to_list(keys): prop = mapper.get_property(key, resolve_synonyms=True) - if prop._is_self_referential(): - raise exceptions.InvalidRequestError("Self-referential query on '%s' property must be constructed manually using an Alias object for the related table." % str(prop)) + if prop._is_self_referential() and not create_aliases: + # TODO: create_aliases automatically ? probably + raise exceptions.InvalidRequestError("Self-referential query on '%s' property requries create_aliases=True argument." % str(prop)) # dont re-join to a table already in our from objects - if prop.select_table not in currenttables: + if prop.select_table not in currenttables or create_aliases: if outerjoin: if prop.secondary: clause = clause.outerjoin(prop.secondary, prop.get_join(mapper, primary=True, secondary=False)) @@ -326,11 +328,13 @@ class Query(object): if create_aliases: join = prop.get_join(mapper, primary=True, secondary=False) secondary_alias = prop.secondary.alias() + aliases.append(secondary_alias) if alias is not None: join = sql_util.ClauseAdapter(alias).traverse(join, clone=True) sql_util.ClauseAdapter(secondary_alias).traverse(join) clause = clause.join(secondary_alias, join) alias = prop.select_table.alias() + aliases.append(alias) join = prop.get_join(mapper, primary=False) join = sql_util.ClauseAdapter(secondary_alias).traverse(join, clone=True) sql_util.ClauseAdapter(alias).traverse(join) @@ -344,15 +348,16 @@ class Query(object): if alias is not None: join = sql_util.ClauseAdapter(alias).traverse(join, clone=True) alias = prop.select_table.alias() + aliases.append(alias) join = sql_util.ClauseAdapter(alias).traverse(join, clone=True) clause = clause.join(alias, join) else: clause = clause.join(prop.select_table, prop.get_join(mapper)) mapper = prop.mapper if create_aliases: - return (clause, mapper, alias) + return (clause, mapper, aliases) else: - return (clause, mapper) + return (clause, mapper, None) def _generative_col_aggregate(self, col, func): """apply the given aggregate function to the query and return the newly @@ -442,7 +447,7 @@ class Query(object): q._group_by = q._group_by + util.to_list(criterion) return q - def join(self, prop): + def join(self, prop, aliased=False): """create a join of this ``Query`` object's criterion to a relationship and return the newly resulting ``Query``. @@ -451,12 +456,13 @@ class Query(object): """ q = self._clone() - (clause, mapper) = self._join_to(prop, outerjoin=False, start=self.mapper) + (clause, mapper, aliases) = self._join_to(prop, outerjoin=False, start=self.mapper, create_aliases=aliased) q._from_obj = [clause] q._joinpoint = mapper + q._aliases = aliases return q - def outerjoin(self, prop): + def outerjoin(self, prop, aliased=False): """create a left outer join of this ``Query`` object's criterion to a relationship and return the newly resulting ``Query``. @@ -464,9 +470,10 @@ class Query(object): property names. """ q = self._clone() - (clause, mapper) = self._join_to(prop, outerjoin=True, start=self.mapper) + (clause, mapper, aliases) = self._join_to(prop, outerjoin=True, start=self.mapper, create_aliases=aliased) q._from_obj = [clause] q._joinpoint = mapper + q._aliases = aliases return q def reset_joinpoint(self): @@ -480,6 +487,7 @@ class Query(object): q = self._clone() q._joinpoint = q.mapper + q._aliases = None return q diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py index c2d1b28bc9..4dcdeac37d 100644 --- a/test/orm/eager_relations.py +++ b/test/orm/eager_relations.py @@ -477,5 +477,6 @@ class SelfReferentialEagerTest(testbase.ORMTest): Node(data='n13') ]) == d self.assert_sql_count(testbase.db, go, 1) + if __name__ == '__main__': testbase.main() diff --git a/test/orm/query.py b/test/orm/query.py index 084f86c348..6fa5c9644c 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -403,6 +403,32 @@ class JoinTest(QueryTest): result = create_session().query(User).select_from(users.join(oalias)).filter(oalias.c.description.in_("order 1", "order 2", "order 3")).join(['orders', 'items']).filter_by(id=4).all() assert [User(id=7, name='jack')] == result + def test_aliased(self): + """test automatic generation of aliased joins.""" + + sess = create_session() + + # test a basic aliasized path + q = sess.query(User).join('addresses', aliased=True).filter_by(email_address='jack@bean.com') + assert [User(id=7)] == q.all() + + q = sess.query(User).join('addresses', aliased=True).filter(Address.email_address=='jack@bean.com') + assert [User(id=7)] == q.all() + + # test two aliasized paths, one to 'orders' and the other to 'orders','items'. + # one row is returned because user 7 has order 3 and also has order 1 which has item 1 + # this tests a o2m join and a m2m join. + q = sess.query(User).join('orders', aliased=True).filter(Order.description=="order 3").join(['orders', 'items'], aliased=True).filter(Item.description=="item 1") + assert q.count() == 1 + assert [User(id=7)] == q.all() + + # test the control version - same joins but not aliased. rows are not returned because order 3 does not have item 1 + # addtionally by placing this test after the previous one, test that the "aliasing" step does not corrupt the + # join clauses that are cached by the relationship. + q = sess.query(User).join('orders').filter(Order.description=="order 3").join(['orders', 'items']).filter(Order.description=="item 1") + assert [] == q.all() + assert q.count() == 0 + class SynonymTest(QueryTest): keep_mappers = True @@ -586,32 +612,40 @@ class InstancesTest(QueryTest): assert q.all() == expected -class FilterByTest(QueryTest): - def test_aliased(self): - """test automatic generation of aliased joins using filter_by().""" - +# this test not working yet +class SelfReferentialTest(object): #testbase.ORMTest): + def define_tables(self, metadata): + global nodes + nodes = Table('nodes', metadata, + Column('id', Integer, primary_key=True), + Column('parent_id', Integer, ForeignKey('nodes.id')), + Column('data', String(30))) + + def test_join(self): + class Node(Base): + def append(self, node): + self.children.append(node) + + mapper(Node, nodes, properties={ + 'children':relation(Node, lazy=True, join_depth=3) + }) sess = create_session() - - # test a basic aliasized path - q = sess.query(User).filter_by(['addresses'], email_address='jack@bean.com') - assert [User(id=7)] == q.all() - - # test two aliasized paths, one to 'orders' and the other to 'orders','items'. - # one row is returned because user 7 has order 3 and also has order 1 which has item 1 - # this tests a o2m join and a m2m join. - q = sess.query(User).filter_by(['orders'], description="order 3").filter_by(['orders', 'items'], description="item 1") - assert q.count() == 1 - assert [User(id=7)] == q.all() - - # test the control version - same joins but not aliased. rows are not returned because order 3 does not have item 1 - # addtionally by placing this test after the previous one, test that the "aliasing" step does not corrupt the - # join clauses that are cached by the relationship. - q = sess.query(User).join('orders').filter_by(description="order 3").join(['orders', 'items']).filter_by(description="item 1") - assert [] == q.all() - assert q.count() == 0 - - - + n1 = Node(data='n1') + n1.append(Node(data='n11')) + n1.append(Node(data='n12')) + n1.append(Node(data='n13')) + n1.children[1].append(Node(data='n121')) + n1.children[1].append(Node(data='n122')) + n1.children[1].append(Node(data='n123')) + sess.save(n1) + sess.flush() + sess.clear() + + # TODO: the aliasing of the join in query._join_to has to limit the aliasing + # among local_side / remote_side (add local_side as an attribute on PropertyLoader) + # also implement this idea in EagerLoader + node = sess.query(Node).join('children', aliased=True).filter_by(data='n122').first() + assert node.data=='n12' if __name__ == '__main__': testbase.main()