From 67cfc86b97f6e7ea43adeeeaa5306ed7977c9e4a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 4 Jun 2007 18:14:59 +0000 Subject: [PATCH] - added all(), first(), and one() - created new test framework for query. migrating old test/orm/mapper.py tests over to new query.py --- lib/sqlalchemy/orm/query.py | 37 ++++- test/orm/alltests.py | 1 + test/orm/mapper.py | 149 -------------------- test/orm/query.py | 262 ++++++++++++++++++++++++++++++++++++ 4 files changed, 295 insertions(+), 154 deletions(-) create mode 100644 test/orm/query.py diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index f2e87b7dea..690d8b2f41 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -535,8 +535,13 @@ class Query(object): else: [keys,p] = self._locate_prop(prop, start=start) clause = self._from_obj[-1] - - currenttables = sql_util.TableFinder(self._from_obj, include_aliases=True) + + currenttables = [clause] + class FindJoinedTables(sql.NoColumnVisitor): + def visit_join(self, join): + currenttables.append(join.left) + currenttables.append(join.right) + FindJoinedTables().traverse(clause) mapper = start for key in keys: @@ -823,15 +828,19 @@ class Query(object): new._distinct = True return new - def list(self): + def all(self): """Return the results represented by this ``Query`` as a list. This results in an execution of the underlying query. """ - return list(self) + + def list(self): + """deprecated. use all()""" - def scalar(self): + return list(self) + + def first(self): """Return the first result of this ``Query``. This results in an execution of the underlying query. @@ -840,6 +849,24 @@ class Query(object): return self[0] else: return self._col_aggregate(self._col, self._func) + + def scalar(self): + """deprecated. use first()""" + return self.first() + + def one(self): + """Return the first result of this ``Query``, raising an exception if more than one row exists. + + This results in an execution of the underlying query. + """ + ret = list(self[0:2]) + + if len(ret) == 1: + return ret[0] + elif len(ret) == 0: + raise exceptions.InvalidRequestError('No rows returned for one()') + else: + raise exceptions.InvalidRequestError('Multiple rows returned for one()') def __iter__(self): return iter(self.select_whereclause()) diff --git a/test/orm/alltests.py b/test/orm/alltests.py index bdcca39979..395b402bdb 100644 --- a/test/orm/alltests.py +++ b/test/orm/alltests.py @@ -6,6 +6,7 @@ import inheritance.alltests as inheritance def suite(): modules_to_test = ( 'orm.attributes', + 'orm.query', 'orm.mapper', 'orm.generative', 'orm.lazytest1', diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 34d0c81e47..558fa62809 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -22,36 +22,6 @@ class MapperSuperTest(AssertMixin): pass class MapperTest(MapperSuperTest): - # TODO: MapperTest has grown much larger than it originally was and needs - # to be broken up among various functions, including querying, session operations, - # mapper configurational issues - def testget(self): - s = create_session() - mapper(User, users) - self.assert_(s.get(User, 19) is None) - u = s.get(User, 7) - u2 = s.get(User, 7) - self.assert_(u is u2) - s.clear() - u2 = s.get(User, 7) - self.assert_(u is not u2) - - def testunicodeget(self): - """test that Query.get properly sets up the type for the bind parameter. using unicode would normally fail - on postgres, mysql and oracle unless it is converted to an encoded string""" - metadata = BoundMetaData(db) - table = Table('foo', metadata, - Column('id', Unicode(10), primary_key=True), - Column('data', Unicode(40))) - try: - table.create() - class LocalFoo(object):pass - mapper(LocalFoo, table) - crit = 'petit voix m\xe2\x80\x99a '.decode('utf-8') - print repr(crit) - create_session().query(LocalFoo).get(crit) - finally: - table.drop() def testpropconflict(self): """test that a backref created against an existing mapper with a property name @@ -259,13 +229,6 @@ class MapperTest(MapperSuperTest): }).compile() self.assert_(User.addresses.property is m.props['addresses']) - def testquery(self): - """test a basic Query.select() operation.""" - mapper(User, users) - l = create_session().query(User).select() - self.assert_result(l, User, *user_result) - l = create_session().query(User).select(users.c.user_name.endswith('ed')) - self.assert_result(l, User, *user_result[1:3]) def testrecursiveselectby(self): """test that no endless loop occurs when traversing for select_by""" @@ -300,118 +263,6 @@ class MapperTest(MapperSuperTest): l = q.select() self.assert_result(l, User, *result) - def testwithparent(self): - """test the with_parent()) method and one-to-many relationships""" - - m = mapper(User, users, properties={ - 'orders':relation(mapper(Order, orders, properties={ - 'items':relation(mapper(Item, orderitems)) - })) - }) - - sess = create_session() - q = sess.query(m) - u1 = q.get_by(user_name='jack') - - # test auto-lookup of property - o = sess.query(Order).with_parent(u1).list() - self.assert_result(o, Order, *user_all_result[0]['orders'][1]) - - # test with explicit property - o = sess.query(Order).with_parent(u1, property='orders').list() - self.assert_result(o, Order, *user_all_result[0]['orders'][1]) - - # test static method - o = Query.query_from_parent(u1, property='orders', session=sess).list() - self.assert_result(o, Order, *user_all_result[0]['orders'][1]) - - # test generative criterion - o = sess.query(Order).with_parent(u1).select_by(orders.c.order_id>2) - self.assert_result(o, Order, *user_all_result[0]['orders'][1][1:]) - - try: - q = sess.query(Item).with_parent(u1) - assert False - except exceptions.InvalidRequestError, e: - assert str(e) == "Could not locate a property which relates instances of class 'Item' to instances of class 'User'" - - def testwithparentm2m(self): - """test the with_parent() method and many-to-many relationships""" - - m = mapper(Item, orderitems, properties = { - 'keywords' : relation(mapper(Keyword, keywords), itemkeywords) - }) - sess = create_session() - i1 = sess.query(Item).get_by(item_id=2) - k = sess.query(Keyword).with_parent(i1) - self.assert_result(k, Keyword, *item_keyword_result[1]['keywords'][1]) - - - def testautojoin(self): - """test functions derived from Query's _join_to function.""" - - m = mapper(User, users, properties={ - 'orders':relation(mapper(Order, orders, properties={ - 'items':relation(mapper(Item, orderitems)) - })) - }) - - sess = create_session() - q = sess.query(m) - - l = q.filter(orderitems.c.item_name=='item 4').join(['orders', 'items']).list() - self.assert_result(l, User, user_result[0]) - - l = q.select_by(item_name='item 4') - self.assert_result(l, User, user_result[0]) - - l = q.filter(orderitems.c.item_name=='item 4').join('item_name').list() - self.assert_result(l, User, user_result[0]) - - l = q.filter(orderitems.c.item_name=='item 4').join('items').list() - self.assert_result(l, User, user_result[0]) - - # test comparing to an object instance - item = sess.query(Item).get_by(item_name='item 4') - - l = sess.query(Order).select_by(items=item) - self.assert_result(l, Order, user_all_result[0]['orders'][1][1]) - - l = q.select_by(items=item) - self.assert_result(l, User, user_result[0]) - - # TODO: this works differently from: - #q = sess.query(User).join(['orders', 'items']).select_by(order_id=3) - # because select_by() doesnt respect query._joinpoint, whereas filter_by does - q = sess.query(User).join(['orders', 'items']).filter_by(order_id=3).list() - self.assert_result(l, User, user_result[0]) - - try: - # this should raise AttributeError - l = q.select_by(items=5) - assert False - except AttributeError: - assert True - - def testautojoinm2m(self): - """test functions derived from Query's _join_to function.""" - - m = mapper(Order, orders, properties = { - 'items' : relation(mapper(Item, orderitems, properties = { - 'keywords' : relation(mapper(Keyword, keywords), itemkeywords) - })) - }) - - sess = create_session() - q = sess.query(m) - - l = q.filter(keywords.c.name=='square').join(['items', 'keywords']).list() - self.assert_result(l, Order, order_result[1]) - - # test comparing to an object instance - item = sess.query(Item).selectfirst() - l = sess.query(Item).select_by(keywords=item.keywords[0]) - assert item == l[0] def testcustomjoin(self): """test that the from_obj parameter to query.select() can be used diff --git a/test/orm/query.py b/test/orm/query.py new file mode 100644 index 0000000000..8d3f5e67d4 --- /dev/null +++ b/test/orm/query.py @@ -0,0 +1,262 @@ +from sqlalchemy import * +from sqlalchemy.orm import * +import testbase + +class QueryTest(testbase.ORMTest): + keep_mappers = True + keep_data = True + + def setUpAll(self): + super(QueryTest, self).setUpAll() + self.install_fixture_data() + self.setup_mappers() + + def define_tables(self, metadata): + global users, orders, addresses, items, order_items, item_keywords, keywords + + users = Table('users', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(30), nullable=False)) + + orders = Table('orders', metadata, + Column('id', Integer, primary_key=True), + Column('user_id', None, ForeignKey('users.id')), + Column('address_id', None, ForeignKey('addresses.id')), + Column('description', String(30)), + Column('isopen', Integer) + ) + + addresses = Table('addresses', metadata, + Column('id', Integer, primary_key=True), + Column('user_id', None, ForeignKey('users.id')), + Column('email_address', String(50), nullable=False)) + + items = Table('items', metadata, + Column('id', Integer, primary_key=True), + Column('description', String(30), nullable=False) + ) + + order_items = Table('order_items', metadata, + Column('item_id', None, ForeignKey('items.id')), + Column('order_id', None, ForeignKey('orders.id'))) + + item_keywords = Table('item_keywords', metadata, + Column('item_id', None, ForeignKey('items.id')), + Column('keyword_id', None, ForeignKey('keywords.id'))) + + keywords = Table('keywords', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(30), nullable=False) + ) + + def install_fixture_data(self): + users.insert().execute( + dict(id = 7, name = 'jack'), + dict(id = 8, name = 'ed'), + dict(id = 9, name = 'fred'), + dict(id = 10, name = 'chuck'), + + ) + addresses.insert().execute( + dict(id = 1, user_id = 7, email_address = "jack@bean.com"), + dict(id = 2, user_id = 8, email_address = "ed@wood.com"), + dict(id = 3, user_id = 8, email_address = "ed@bettyboop.com"), + dict(id = 4, user_id = 8, email_address = "ed@lala.com"), + dict(id = 5, user_id = 9, email_address = "fred@fred.com"), + ) + orders.insert().execute( + dict(id = 1, user_id = 7, description = 'order 1', isopen=0, address_id=1), + dict(id = 2, user_id = 9, description = 'order 2', isopen=0, address_id=4), + dict(id = 3, user_id = 7, description = 'order 3', isopen=1, address_id=1), + dict(id = 4, user_id = 9, description = 'order 4', isopen=1, address_id=4), + dict(id = 5, user_id = 7, description = 'order 5', isopen=0, address_id=1) + ) + items.insert().execute( + dict(id=1, description='item 1'), + dict(id=2, description='item 2'), + dict(id=3, description='item 3'), + dict(id=4, description='item 4'), + dict(id=5, description='item 5'), + ) + order_items.insert().execute( + dict(item_id=1, order_id=1), + dict(item_id=2, order_id=1), + dict(item_id=3, order_id=1), + + dict(item_id=1, order_id=2), + dict(item_id=2, order_id=2), + dict(item_id=3, order_id=2), + dict(item_id=2, order_id=2), + + dict(item_id=3, order_id=3), + dict(item_id=4, order_id=3), + dict(item_id=5, order_id=3), + + dict(item_id=5, order_id=4), + dict(item_id=1, order_id=4), + + dict(item_id=5, order_id=5), + ) + keywords.insert().execute( + dict(id=1, name='blue'), + dict(id=2, name='red'), + dict(id=3, name='green'), + dict(id=4, name='big'), + dict(id=5, name='small'), + dict(id=6, name='round'), + dict(id=7, name='square') + ) + + # this many-to-many table has the keywords inserted + # in primary key order, to appease the unit tests. + # this is because postgres, oracle, and sqlite all support + # true insert-order row id, but of course our pal MySQL does not, + # so the best it can do is order by, well something, so there you go. + item_keywords.insert().execute( + dict(keyword_id=2, item_id=1), + dict(keyword_id=2, item_id=2), + dict(keyword_id=4, item_id=1), + dict(keyword_id=6, item_id=1), + dict(keyword_id=5, item_id=2), + dict(keyword_id=3, item_id=3), + dict(keyword_id=4, item_id=3), + dict(keyword_id=7, item_id=2), + dict(keyword_id=6, item_id=3) + ) + + def setup_mappers(self): + global User, Order, Item, Keyword, Address, Base + + class Base(object): + def __init__(self, **kwargs): + for k in kwargs: + setattr(self, k, kwargs[k]) + def __eq__(self, other): + for attr in dir(self): + if attr[0] == '_': + continue + value = getattr(self, attr) + if isinstance(value, list): + for (us, them) in zip(value, getattr(other, attr)): + if us != them: + return False + else: + return True + else: + if value is not None: + return value == getattr(other, attr) + + class User(Base):pass + class Order(Base):pass + class Item(Base):pass + class Keyword(Base):pass + class Address(Base):pass + + mapper(User, users, properties={ + 'orders':relation(Order, backref='user'), # o2m, m2o + }) + mapper(Address, addresses) + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items), #m2m + 'address':relation(Address), # m2o + }) + mapper(Item, items, properties={ + 'keywords':relation(Keyword, secondary=item_keywords) #m2m + }) + mapper(Keyword, keywords) + +class GetTest(QueryTest): + def test_get(self): + s = create_session() + assert s.query(User).get(19) is None + u = s.query(User).get(7) + u2 = s.query(User).get(7) + assert u is u2 + s.clear() + u2 = s.query(User).get(7) + assert u is not u2 + + def test_unicode(self): + """test that Query.get properly sets up the type for the bind parameter. using unicode would normally fail + on postgres, mysql and oracle unless it is converted to an encoded string""" + + table = Table('unicode_data', users.metadata, + Column('id', Unicode(40), primary_key=True), + Column('data', Unicode(40))) + table.create() + ustring = 'petit voix m\xe2\x80\x99a '.decode('utf-8') + table.insert().execute(id=ustring, data=ustring) + class LocalFoo(Base):pass + mapper(LocalFoo, table) + assert create_session().query(LocalFoo).get(ustring) == LocalFoo(id=ustring, data=ustring) + +class FilterTest(QueryTest): + def test_basic(self): + assert create_session().query(User).all() == [User(id=7), User(id=8), User(id=9),User(id=10)] + + def test_onefilter(self): + assert create_session().query(User).filter(users.c.name.endswith('ed')).all() == [User(id=8), User(id=9)] + +class ParentTest(QueryTest): + def test_o2m(self): + sess = create_session() + q = sess.query(User) + + u1 = q.filter_by(name='jack').one() + + # test auto-lookup of property + o = sess.query(Order).with_parent(u1).all() + assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o + + # test with explicit property + o = sess.query(Order).with_parent(u1, property='orders').all() + assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o + + # test static method + o = Query.query_from_parent(u1, property='orders', session=sess).all() + assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o + + # test generative criterion + o = sess.query(Order).with_parent(u1).filter(orders.c.id>2).all() + assert [Order(description="order 3"), Order(description="order 5")] == o + + def test_noparent(self): + sess = create_session() + q = sess.query(User) + + u1 = q.filter_by(name='jack').one() + + try: + q = sess.query(Item).with_parent(u1) + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Could not locate a property which relates instances of class 'Item' to instances of class 'User'" + + def test_m2m(self): + sess = create_session() + i1 = sess.query(Item).filter_by(id=2).one() + k = sess.query(Keyword).with_parent(i1).all() + assert [Keyword(name='red'), Keyword(name='small'), Keyword(name='square')] == k + + +class JoinTest(QueryTest): + def test_overlapping_paths(self): + result = create_session().query(User).join(['orders', 'items']).filter_by(id=3).join(['orders','address']).filter_by(id=1).all() + assert [User(id=7, name='jack')] == result + + def test_overlap_with_aliases(self): + oalias = orders.alias('oalias') + + result = create_session().query(User).select_from(users.join(oalias)).filter(oalias.c.description.in_("order 1", "order 2", "order 3")).join(['orders', 'items']).all() + assert [User(id=7, name='jack'), User(id=9, name='fred')] == result + + result = create_session().query(User).select_from(users.join(oalias)).filter(oalias.c.description.in_("order 1", "order 2", "order 3")).join(['orders', 'items']).filter_by(id=4).all() + assert [User(id=7, name='jack')] == result + + + + +if __name__ == '__main__': + testbase.main() + + -- 2.47.3