From: Mike Bayer Date: Sun, 10 Jun 2007 23:53:03 +0000 (+0000) Subject: - added "aliased joins" feature to query.filter_by() X-Git-Tag: rel_0_4_6~211 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=cd358b8c9df4eb77b4c09ce12a340fb8bb67546a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - added "aliased joins" feature to query.filter_by() - further work on modernizing/cleaning up unit tests --- diff --git a/CHANGES b/CHANGES index 9952137830..1c2cc8d604 100644 --- a/CHANGES +++ b/CHANGES @@ -7,6 +7,11 @@ new replacements. - query.list() replaced with query.all() - removed ancient query.select_by_attributename() capability. + - added "aliased joins" positional argument to the front of + filter_by(). this allows auto-creation of joins that are aliased + locally to the individual filter_by() call. This allows the + auto-construction of joins which cross the same paths but + are querying divergent criteria. - along with recent speedups to ResultProxy, total number of function calls significantly reduced for large loads. test/perf/masseagerload.py reports 0.4 as having the fewest number diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index d4318049b4..a7bf73b80a 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -255,22 +255,37 @@ class Query(object): import properties - clause = None + if len(args) > 1: + raise exceptions.InvalidRequestError("filter_by() takes either zero positional arguments, or one scalar or list argument indicating a property search path.") + if len(args) == 1: + path = args[0] + (join, joinpoint, alias) = self._join_to(path, outerjoin=False, start=self.mapper, create_aliases=True) + clause = None + else: + alias = None + join = None + clause = None + joinpoint = self._joinpoint for key, value in kwargs.iteritems(): - prop = self._joinpoint.props[key] + prop = joinpoint.props[key] if isinstance(prop, properties.PropertyLoader): c = self._with_lazy_criterion(value, prop, True) # & self.join_via(keys[:-1]) - use aliasized join feature else: c = prop.compare(value) # & self.join_via(keys) - use aliasized join feature + if alias is not None: + sql_util.ClauseAdapter(alias).traverse(c) if clause is None: clause = c else: clause &= c - return self.filter(clause) + if join is not None: + return self.select_from(join).filter(clause) + else: + return self.filter(clause) - def _join_to(self, prop, outerjoin=False, start=None): + def _join_to(self, prop, outerjoin=False, start=None, create_aliases=False): if start is None: start = self._joinpoint @@ -293,6 +308,7 @@ class Query(object): FindJoinedTables().traverse(clause) mapper = start + alias = None for key in keys: prop = mapper.props[key] if prop._is_self_referential(): @@ -307,12 +323,36 @@ class Query(object): clause = clause.outerjoin(prop.select_table, prop.get_join(mapper)) else: if prop.secondary: - clause = clause.join(prop.secondary, prop.get_join(mapper, primary=True, secondary=False)) - clause = clause.join(prop.select_table, prop.get_join(mapper, primary=False)) + if create_aliases: + join = prop.get_join(mapper, primary=True, secondary=False).copy_container() + secondary_alias = prop.secondary.alias() + if alias is not None: + sql_util.ClauseAdapter(alias).traverse(join) + sql_util.ClauseAdapter(secondary_alias).traverse(join) + clause = clause.join(secondary_alias, join) + alias = prop.select_table.alias() + join = prop.get_join(mapper, primary=False).copy_container() + sql_util.ClauseAdapter(secondary_alias).traverse(join) + sql_util.ClauseAdapter(alias).traverse(join) + clause = clause.join(alias, join) + else: + clause = clause.join(prop.secondary, prop.get_join(mapper, primary=True, secondary=False)) + clause = clause.join(prop.select_table, prop.get_join(mapper, primary=False)) else: - clause = clause.join(prop.select_table, prop.get_join(mapper)) + if create_aliases: + join = prop.get_join(mapper).copy_container() + if alias is not None: + sql_util.ClauseAdapter(alias).traverse(join) + alias = prop.select_table.alias() + sql_util.ClauseAdapter(alias).traverse(join) + clause = clause.join(alias, join) + else: + clause = clause.join(prop.select_table, prop.get_join(mapper)) mapper = prop.mapper - return (clause, mapper) + if create_aliases: + return (clause, mapper, alias) + else: + return (clause, mapper) def _generative_col_aggregate(self, col, func): """apply the given aggregate function to the query and return the newly diff --git a/test/orm/alltests.py b/test/orm/alltests.py index 395b402bdb..133d517661 100644 --- a/test/orm/alltests.py +++ b/test/orm/alltests.py @@ -7,6 +7,7 @@ def suite(): modules_to_test = ( 'orm.attributes', 'orm.query', + 'orm.lazy_relations', 'orm.mapper', 'orm.generative', 'orm.lazytest1', diff --git a/test/orm/fixtures.py b/test/orm/fixtures.py new file mode 100644 index 0000000000..b79fdb5ba5 --- /dev/null +++ b/test/orm/fixtures.py @@ -0,0 +1,152 @@ +from sqlalchemy import * + +class Base(object): + def __init__(self, **kwargs): + for k in kwargs: + setattr(self, k, kwargs[k]) + + def __ne__(self, other): + return not self.__eq__(other) + + def __eq__(self, other): + """'passively' compare this object to another. + + only look at attributes that are present on the source object. + + """ + # use __dict__ to avoid instrumented properties + for attr in self.__dict__.keys(): + if attr[0] == '_': + continue + value = getattr(self, attr) + if hasattr(value, '__iter__') and not isinstance(value, basestring): + if len(value) == 0: + continue + for (us, them) in zip(value, getattr(other, attr)): + if us != them: + return False + else: + continue + else: + if value is not None: + if value != getattr(other, attr): + return False + else: + return True + +class User(Base):pass +class Order(Base):pass +class Item(Base):pass +class Keyword(Base):pass +class Address(Base):pass + +metadata = MetaData() + +users = Table('users', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(30), nullable=False)) + +orders = Table('orders', metadata, + Column('id', Integer, primary_key=True), + Column('user_id', None, ForeignKey('users.id')), + Column('address_id', None, ForeignKey('addresses.id')), + Column('description', String(30)), + Column('isopen', Integer) + ) + +addresses = Table('addresses', metadata, + Column('id', Integer, primary_key=True), + Column('user_id', None, ForeignKey('users.id')), + Column('email_address', String(50), nullable=False)) + +items = Table('items', metadata, + Column('id', Integer, primary_key=True), + Column('description', String(30), nullable=False) + ) + +order_items = Table('order_items', metadata, + Column('item_id', None, ForeignKey('items.id')), + Column('order_id', None, ForeignKey('orders.id'))) + +item_keywords = Table('item_keywords', metadata, + Column('item_id', None, ForeignKey('items.id')), + Column('keyword_id', None, ForeignKey('keywords.id'))) + +keywords = Table('keywords', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(30), nullable=False) + ) + +def install_fixture_data(): + users.insert().execute( + dict(id = 7, name = 'jack'), + dict(id = 8, name = 'ed'), + dict(id = 9, name = 'fred'), + dict(id = 10, name = 'chuck'), + + ) + addresses.insert().execute( + dict(id = 1, user_id = 7, email_address = "jack@bean.com"), + dict(id = 2, user_id = 8, email_address = "ed@wood.com"), + dict(id = 3, user_id = 8, email_address = "ed@bettyboop.com"), + dict(id = 4, user_id = 8, email_address = "ed@lala.com"), + dict(id = 5, user_id = 9, email_address = "fred@fred.com"), + ) + orders.insert().execute( + dict(id = 1, user_id = 7, description = 'order 1', isopen=0, address_id=1), + dict(id = 2, user_id = 9, description = 'order 2', isopen=0, address_id=4), + dict(id = 3, user_id = 7, description = 'order 3', isopen=1, address_id=1), + dict(id = 4, user_id = 9, description = 'order 4', isopen=1, address_id=4), + dict(id = 5, user_id = 7, description = 'order 5', isopen=0, address_id=1) + ) + items.insert().execute( + dict(id=1, description='item 1'), + dict(id=2, description='item 2'), + dict(id=3, description='item 3'), + dict(id=4, description='item 4'), + dict(id=5, description='item 5'), + ) + order_items.insert().execute( + dict(item_id=1, order_id=1), + dict(item_id=2, order_id=1), + dict(item_id=3, order_id=1), + + dict(item_id=1, order_id=2), + dict(item_id=2, order_id=2), + dict(item_id=3, order_id=2), + + dict(item_id=3, order_id=3), + dict(item_id=4, order_id=3), + dict(item_id=5, order_id=3), + + dict(item_id=1, order_id=4), + dict(item_id=5, order_id=4), + + dict(item_id=5, order_id=5), + ) + keywords.insert().execute( + dict(id=1, name='blue'), + dict(id=2, name='red'), + dict(id=3, name='green'), + dict(id=4, name='big'), + dict(id=5, name='small'), + dict(id=6, name='round'), + dict(id=7, name='square') + ) + + # this many-to-many table has the keywords inserted + # in primary key order, to appease the unit tests. + # this is because postgres, oracle, and sqlite all support + # true insert-order row id, but of course our pal MySQL does not, + # so the best it can do is order by, well something, so there you go. + item_keywords.insert().execute( + dict(keyword_id=2, item_id=1), + dict(keyword_id=2, item_id=2), + dict(keyword_id=4, item_id=1), + dict(keyword_id=6, item_id=1), + dict(keyword_id=5, item_id=2), + dict(keyword_id=3, item_id=3), + dict(keyword_id=4, item_id=3), + dict(keyword_id=7, item_id=2), + dict(keyword_id=6, item_id=3) + ) diff --git a/test/orm/inheritance/polymorph2.py b/test/orm/inheritance/polymorph2.py index ac47ff33ce..a678994493 100644 --- a/test/orm/inheritance/polymorph2.py +++ b/test/orm/inheritance/polymorph2.py @@ -205,20 +205,8 @@ class RelationTest3(testbase.ORMTest): Column('data', String(30)) ) - def testrelationonbaseclass_j1_nodata(self): - self.do_test("join1", False) - def testrelationonbaseclass_j2_nodata(self): - self.do_test("join2", False) - def testrelationonbaseclass_j1_data(self): - self.do_test("join1", True) - def testrelationonbaseclass_j2_data(self): - self.do_test("join2", True) - def testrelationonbaseclass_j3_nodata(self): - self.do_test("join3", False) - def testrelationonbaseclass_j3_data(self): - self.do_test("join3", True) - - def do_test(self, jointype="join1", usedata=False): +def generate_test(jointype="join1", usedata=False): + def do_test(self): class Person(AttrSettable): pass class Manager(Person): @@ -239,12 +227,14 @@ class RelationTest3(testbase.ORMTest): 'manager':join(people, managers, people.c.person_id==managers.c.person_id), 'person':people.select(people.c.type=='person') }, None) - elif jointype == "join3": + elif jointype == 'join3': + poly_union = people.outerjoin(managers) + elif jointype == "join4": poly_union=None - + if usedata: mapper(Data, data) - + mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, polymorphic_identity='manager') if usedata: mapper(Person, people, select_table=poly_union, polymorphic_identity='person', polymorphic_on=people.c.type, @@ -275,7 +265,7 @@ class RelationTest3(testbase.ORMTest): sess.save(m) sess.save(p) sess.flush() - + sess.clear() p = sess.query(Person).get(p.person_id) p2 = sess.query(Person).get(p2.person_id) @@ -288,7 +278,15 @@ class RelationTest3(testbase.ORMTest): if usedata: assert p.data.data == 'ps data' assert m.data.data == 'ms data' + + do_test.__name__ = 'test_relationonbaseclass_%s_%s' % (jointype, data and "nodata" or "data") + return do_test +for jointype in ["join1", "join2", "join3", "join4"]: + for data in (True, False): + func = generate_test(jointype, data) + setattr(RelationTest3, func.__name__, func) + class RelationTest4(testbase.ORMTest): def define_tables(self, metadata): @@ -723,7 +721,7 @@ class GenerativeTest(testbase.AssertMixin): # test these twice because theres caching involved, as well previous issues that modified the polymorphic union for x in range(0, 2): - r = session.query(Person).filter_by(people.c.name.like('%2')).join('status').filter_by(name="active") + r = session.query(Person).filter(people.c.name.like('%2')).join('status').filter_by(name="active") assert str(list(r)) == "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]" r = session.query(Engineer).join('status').filter(people.c.name.in_('E2', 'E3', 'E4', 'M4', 'M2', 'M1') & (status.c.name=="active")) assert str(list(r)) == "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]" diff --git a/test/orm/lazy_relations.py b/test/orm/lazy_relations.py new file mode 100644 index 0000000000..e0a6ceeed4 --- /dev/null +++ b/test/orm/lazy_relations.py @@ -0,0 +1,182 @@ +from sqlalchemy import * +from sqlalchemy.orm import * +import testbase + +from fixtures import * +from query import QueryTest + +class LazyTest(QueryTest): + keep_mappers = False + + def setup_mappers(self): + pass + + def test_basic(self): + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), lazy=True) + }) + sess = create_session() + q = sess.query(User) + assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(users.c.id == 7).all() + + def test_bindstosession(self): + """test that lazy loaders use the mapper's contextual session if the parent instance + is not in a session, and that an error is raised if no contextual session""" + + from sqlalchemy.ext.sessioncontext import SessionContext + ctx = SessionContext(create_session) + m = mapper(User, users, properties = dict( + addresses = relation(mapper(Address, addresses, extension=ctx.mapper_extension), lazy=True) + ), extension=ctx.mapper_extension) + q = ctx.current.query(m) + u = q.filter(users.c.id == 7).first() + ctx.current.expunge(u) + assert User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')]) == u + + clear_mappers() + + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), lazy=True) + }) + try: + sess = create_session() + q = sess.query(User) + u = q.filter(users.c.id == 7).first() + sess.expunge(u) + assert User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')]) == u + except exceptions.InvalidRequestError, err: + assert "not bound to a Session, and no contextual session" in str(err) + + def test_orderby(self): + mapper(User, users, properties = { + 'addresses':relation(mapper(Address, addresses), lazy=True, order_by=addresses.c.email_address), + }) + q = create_session().query(User) + assert [ + User(id=7, addresses=[ + Address(id=1) + ]), + User(id=8, addresses=[ + Address(id=3, email_address='ed@bettyboop.com'), + Address(id=4, email_address='ed@lala.com'), + Address(id=2, email_address='ed@wood.com') + ]), + User(id=9, addresses=[ + Address(id=5) + ]), + User(id=10, addresses=[]) + ] == q.all() + + def test_orderby_secondary(self): + """tests that a regular mapper select on a single table can order by a relation to a second table""" + + mapper(Address, addresses) + + mapper(User, users, properties = dict( + addresses = relation(Address, lazy=True), + )) + q = create_session().query(User) + l = q.filter(users.c.id==addresses.c.user_id).order_by(addresses.c.email_address).all() + assert [ + User(id=8, addresses=[ + Address(id=2, email_address='ed@wood.com'), + Address(id=3, email_address='ed@bettyboop.com'), + Address(id=4, email_address='ed@lala.com'), + ]), + User(id=9, addresses=[ + Address(id=5) + ]), + User(id=7, addresses=[ + Address(id=1) + ]), + ] == l + + def test_orderby_desc(self): + mapper(Address, addresses) + + mapper(User, users, properties = dict( + addresses = relation(Address, lazy=True, order_by=[desc(addresses.c.email_address)]), + )) + sess = create_session() + assert [ + User(id=7, addresses=[ + Address(id=1) + ]), + User(id=8, addresses=[ + Address(id=2, email_address='ed@wood.com'), + Address(id=4, email_address='ed@lala.com'), + Address(id=3, email_address='ed@bettyboop.com'), + ]), + User(id=9, addresses=[ + Address(id=5) + ]), + User(id=10, addresses=[]) + ] == sess.query(User).all() + + def test_no_orphan(self): + """test that a lazily loaded child object is not marked as an orphan""" + + mapper(User, users, properties={ + 'addresses':relation(Address, cascade="all,delete-orphan", lazy=True) + }) + mapper(Address, addresses) + + sess = create_session() + user = sess.query(User).get(7) + assert getattr(User, 'addresses').hasparent(user.addresses[0], optimistic=True) + assert not class_mapper(Address)._is_orphan(user.addresses[0]) + + def test_limit(self): + """test limit operations combined with lazy-load relationships.""" + + mapper(Item, items) + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items, lazy=True) + }) + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), lazy=True), + 'orders':relation(Order, lazy=True) + }) + + sess = create_session() + q = sess.query(User) + + if testbase.db.engine.name == 'mssql': + l = q.limit(2).all() + assert self.user_all_result[:2] == l + else: + l = q.limit(2).offset(1).all() + print l + print self.user_all_result[1:3] + assert self.user_all_result[1:3] == l + + def test_distinct(self): + mapper(Item, items) + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items, lazy=True) + }) + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), lazy=True), + 'orders':relation(Order, lazy=True) + }) + + sess = create_session() + q = sess.query(User) + + # use a union all to get a lot of rows to join against + u2 = users.alias('u2') + s = union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u') + print [key for key in s.c.keys()] + l = q.filter(s.c.u2_id==User.c.id).distinct().all() + assert self.user_all_result == l + + def test_onetoone(self): + mapper(User, users, properties = dict( + address = relation(mapper(Address, addresses), lazy=True, uselist=False) + )) + q = create_session().query(User) + l = q.filter(users.c.id == 7).all() + assert [User(id=7, address=Address(id=1))] == l + +if __name__ == '__main__': + testbase.main() diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 4312598267..97de217120 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -848,31 +848,7 @@ class NoLoadTest(MapperSuperTest): ) -class LazyTest(MapperSuperTest): - - def testbasic(self): - """tests a basic one-to-many lazy load""" - m = mapper(User, users, properties = dict( - addresses = relation(mapper(Address, addresses), lazy = True) - )) - q = create_session().query(m) - l = q.select(users.c.user_id == 7) - self.assert_result(l, User, - {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])}, - ) - - def testbindstosession(self): - ctx = SessionContext(create_session) - m = mapper(User, users, properties = dict( - addresses = relation(mapper(Address, addresses, extension=ctx.mapper_extension), lazy=True) - ), extension=ctx.mapper_extension) - q = ctx.current.query(m) - u = q.filter(users.c.user_id == 7).selectfirst() - ctx.current.expunge(u) - self.assert_result([u], User, - {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])}, - ) - +class MapperExtensionTest(MapperSuperTest): def testcreateinstance(self): class Ext(MapperExtension): def create_instance(self, *args, **kwargs): @@ -885,106 +861,9 @@ class LazyTest(MapperSuperTest): q = create_session().query(m) l = q.select(); self.assert_result(l, User, *user_address_result) - - def testorderby(self): - m = mapper(Address, addresses) - - m = mapper(User, users, properties = dict( - addresses = relation(m, lazy = True, order_by=addresses.c.email_address), - )) - q = create_session().query(m) - l = q.select() - - self.assert_result(l, User, - {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])}, - {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@bettyboop.com'}, {'email_address':'ed@lala.com'}, {'email_address':'ed@wood.com'}])}, - {'user_id' : 9, 'addresses' : (Address, [])} - ) - - def testorderby_select(self): - """tests that a regular mapper select on a single table can order by a relation to a second table""" - m = mapper(Address, addresses) - - m = mapper(User, users, properties = dict( - addresses = relation(m, lazy = True), - )) - q = create_session().query(m) - l = q.select(users.c.user_id==addresses.c.user_id, order_by=addresses.c.email_address) - - self.assert_result(l, User, - {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@wood.com'}, {'email_address':'ed@bettyboop.com'}, {'email_address':'ed@lala.com'}, ])}, - {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])}, - ) - - def testorderby_desc(self): - m = mapper(Address, addresses) - - m = mapper(User, users, properties = dict( - addresses = relation(m, lazy = True, order_by=[desc(addresses.c.email_address)]), - )) - q = create_session().query(m) - l = q.select() - - self.assert_result(l, User, - {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])}, - {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@wood.com'}, {'email_address':'ed@lala.com'}, {'email_address':'ed@bettyboop.com'}])}, - {'user_id' : 9, 'addresses' : (Address, [])}, - ) - - def testorphanstate(self): - """test that a lazily loaded child object is not marked as an orphan""" - m = mapper(User, users, properties={ - 'addresses':relation(Address, cascade="all,delete-orphan", lazy=True) - }) - mapper(Address, addresses) - - q = create_session().query(m) - user = q.get(7) - assert getattr(User, 'addresses').hasparent(user.addresses[0], optimistic=True) - assert not class_mapper(Address)._is_orphan(user.addresses[0]) - - def testlimit(self): - ordermapper = mapper(Order, orders, properties = dict( - items = relation(mapper(Item, orderitems), lazy = True) - )) - - m = mapper(User, users, properties = dict( - addresses = relation(mapper(Address, addresses), lazy = True), - orders = relation(ordermapper, primaryjoin = users.c.user_id==orders.c.user_id, lazy = True), - )) - sess= create_session() - q = sess.query(m) - - if db.engine.name == 'mssql': - l = q.select(limit=2) - self.assert_result(l, User, *user_all_result[:2]) - else: - l = q.select(limit=2, offset=1) - self.assert_result(l, User, *user_all_result[1:3]) - - # use a union all to get a lot of rows to join against - u2 = users.alias('u2') - s = union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u') - print [key for key in s.c.keys()] - l = q.select(s.c.u2_user_id==User.c.user_id, distinct=True) - self.assert_result(l, User, *user_all_result) - - sess.clear() - clear_mappers() - m = mapper(Item, orderitems, properties = dict( - keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy = True), - )) - - l = sess.query(m).select((Item.c.item_name=='item 2') | (Item.c.item_name=='item 5') | (Item.c.item_name=='item 3'), order_by=[Item.c.item_id], limit=2) - self.assert_result(l, Item, *[item_keyword_result[1], item_keyword_result[2]]) + +class LazyTest(MapperSuperTest): - def testonetoone(self): - m = mapper(User, users, properties = dict( - address = relation(mapper(Address, addresses), lazy = True, uselist = False) - )) - q = create_session().query(m) - l = q.select(users.c.user_id == 7) - self.assert_result(l, User, {'user_id':7, 'address' : (Address, {'address_id':1})}) def testbackwardsonetoone(self): m = mapper(Address, addresses, properties = dict( diff --git a/test/orm/query.py b/test/orm/query.py index b7b906466d..9527b25d83 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -1,6 +1,7 @@ from sqlalchemy import * from sqlalchemy.orm import * import testbase +from fixtures import * class Base(object): def __init__(self, **kwargs): @@ -42,132 +43,19 @@ class QueryTest(testbase.ORMTest): def setUpAll(self): super(QueryTest, self).setUpAll() - self.install_fixture_data() + install_fixture_data() self.setup_mappers() - - def define_tables(self, metadata): - global users, orders, addresses, items, order_items, item_keywords, keywords - - users = Table('users', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(30), nullable=False)) - - orders = Table('orders', metadata, - Column('id', Integer, primary_key=True), - Column('user_id', None, ForeignKey('users.id')), - Column('address_id', None, ForeignKey('addresses.id')), - Column('description', String(30)), - Column('isopen', Integer) - ) - - addresses = Table('addresses', metadata, - Column('id', Integer, primary_key=True), - Column('user_id', None, ForeignKey('users.id')), - Column('email_address', String(50), nullable=False)) - - items = Table('items', metadata, - Column('id', Integer, primary_key=True), - Column('description', String(30), nullable=False) - ) - - order_items = Table('order_items', metadata, - Column('item_id', None, ForeignKey('items.id')), - Column('order_id', None, ForeignKey('orders.id'))) - - item_keywords = Table('item_keywords', metadata, - Column('item_id', None, ForeignKey('items.id')), - Column('keyword_id', None, ForeignKey('keywords.id'))) - - keywords = Table('keywords', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(30), nullable=False) - ) - - def install_fixture_data(self): - users.insert().execute( - dict(id = 7, name = 'jack'), - dict(id = 8, name = 'ed'), - dict(id = 9, name = 'fred'), - dict(id = 10, name = 'chuck'), - - ) - addresses.insert().execute( - dict(id = 1, user_id = 7, email_address = "jack@bean.com"), - dict(id = 2, user_id = 8, email_address = "ed@wood.com"), - dict(id = 3, user_id = 8, email_address = "ed@bettyboop.com"), - dict(id = 4, user_id = 8, email_address = "ed@lala.com"), - dict(id = 5, user_id = 9, email_address = "fred@fred.com"), - ) - orders.insert().execute( - dict(id = 1, user_id = 7, description = 'order 1', isopen=0, address_id=1), - dict(id = 2, user_id = 9, description = 'order 2', isopen=0, address_id=4), - dict(id = 3, user_id = 7, description = 'order 3', isopen=1, address_id=1), - dict(id = 4, user_id = 9, description = 'order 4', isopen=1, address_id=4), - dict(id = 5, user_id = 7, description = 'order 5', isopen=0, address_id=1) - ) - items.insert().execute( - dict(id=1, description='item 1'), - dict(id=2, description='item 2'), - dict(id=3, description='item 3'), - dict(id=4, description='item 4'), - dict(id=5, description='item 5'), - ) - order_items.insert().execute( - dict(item_id=1, order_id=1), - dict(item_id=2, order_id=1), - dict(item_id=3, order_id=1), - - dict(item_id=1, order_id=2), - dict(item_id=2, order_id=2), - dict(item_id=3, order_id=2), - dict(item_id=2, order_id=2), - - dict(item_id=3, order_id=3), - dict(item_id=4, order_id=3), - dict(item_id=5, order_id=3), - - dict(item_id=5, order_id=4), - dict(item_id=1, order_id=4), - - dict(item_id=5, order_id=5), - ) - keywords.insert().execute( - dict(id=1, name='blue'), - dict(id=2, name='red'), - dict(id=3, name='green'), - dict(id=4, name='big'), - dict(id=5, name='small'), - dict(id=6, name='round'), - dict(id=7, name='square') - ) - - # this many-to-many table has the keywords inserted - # in primary key order, to appease the unit tests. - # this is because postgres, oracle, and sqlite all support - # true insert-order row id, but of course our pal MySQL does not, - # so the best it can do is order by, well something, so there you go. - item_keywords.insert().execute( - dict(keyword_id=2, item_id=1), - dict(keyword_id=2, item_id=2), - dict(keyword_id=4, item_id=1), - dict(keyword_id=6, item_id=1), - dict(keyword_id=5, item_id=2), - dict(keyword_id=3, item_id=3), - dict(keyword_id=4, item_id=3), - dict(keyword_id=7, item_id=2), - dict(keyword_id=6, item_id=3) - ) - - def setup_mappers(self): - global User, Order, Item, Keyword, Address + def tearDownAll(self): + clear_mappers() + super(QueryTest, self).tearDownAll() + + def define_tables(self, meta): + # a slight dirty trick here. + meta.tables = metadata.tables + metadata.connect(meta.engine) - class User(Base):pass - class Order(Base):pass - class Item(Base):pass - class Keyword(Base):pass - class Address(Base):pass - + def setup_mappers(self): mapper(User, users, properties={ 'addresses':relation(Address), 'orders':relation(Order, backref='user'), # o2m, m2o @@ -199,6 +87,30 @@ class QueryTest(testbase.ORMTest): User(id=10, addresses=[]) ] + @property + def user_all_result(self): + return [ + User(id=7, addresses=[ + Address(id=1) + ], orders=[ + Order(description='order 1', items=[Item(description='item 1'), Item(description='item 2'), Item(description='item 3')]), + Order(description='order 3'), + Order(description='order 5'), + ]), + User(id=8, addresses=[ + Address(id=2), + Address(id=3), + Address(id=4) + ]), + User(id=9, addresses=[ + Address(id=5) + ], orders=[ + Order(description='order 2', items=[Item(description='item 1'), Item(description='item 2'), Item(description='item 3')]), + Order(description='order 4', items=[Item(description='item 1'), Item(description='item 5')]), + ]), + User(id=10, addresses=[]) + ] + class GetTest(QueryTest): def test_get(self): s = create_session() @@ -289,6 +201,7 @@ class ParentTest(QueryTest): class JoinTest(QueryTest): def test_overlapping_paths(self): + # load a user who has an order that contains item id 3 and address id 1 (order 3, owned by jack) result = create_session().query(User).join(['orders', 'items']).filter_by(id=3).join(['orders','address']).filter_by(id=1).all() assert [User(id=7, name='jack')] == result @@ -378,7 +291,7 @@ class InstancesTest(QueryTest): def test_multi_columns(self): sess = create_session() - (user7, user8, user9, user10) = sess.query(User).select() + (user7, user8, user9, user10) = sess.query(User).all() expected = [(user7, 1), (user8, 3), (user9, 1), @@ -398,7 +311,7 @@ class InstancesTest(QueryTest): @testbase.unsupported('mysql') # only because of "+" operator requiring "concat" in mysql (fix #475) def test_two_columns(self): sess = create_session() - (user7, user8, user9, user10) = sess.query(User).select() + (user7, user8, user9, user10) = sess.query(User).all() expected = [ (user7, 1, "Name:jack"), (user8, 3, "Name:ed"), @@ -410,7 +323,32 @@ class InstancesTest(QueryTest): l = q.instances(s.execute(), "count", "concat") assert l == expected - +class FilterByTest(QueryTest): + def test_aliased(self): + """test automatic generation of aliased joins using filter_by().""" + + sess = create_session() + + # test a basic aliasized path + q = sess.query(User).filter_by(['addresses'], email_address='jack@bean.com') + assert [User(id=7)] == q.all() + + # test two aliasized paths, one to 'orders' and the other to 'orders','items'. + # one row is returned because user 7 has order 3 and also has order 1 which has item 1 + # this tests a o2m join and a m2m join. + q = sess.query(User).filter_by(['orders'], description="order 3").filter_by(['orders', 'items'], description="item 1") + assert q.count() == 1 + assert [User(id=7)] == q.all() + + # test the control version - same joins but not aliased. rows are not returned because order 3 does not have item 1 + # addtionally by placing this test after the previous one, test that the "aliasing" step does not corrupt the + # join clauses that are cached by the relationship. + q = sess.query(User).join('orders').filter_by(description="order 3").join(['orders', 'items']).filter_by(description="item 1") + assert [] == q.all() + assert q.count() == 0 + + + if __name__ == '__main__': testbase.main()