]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added all(), first(), and one()
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 4 Jun 2007 18:14:59 +0000 (18:14 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 4 Jun 2007 18:14:59 +0000 (18:14 +0000)
- created new test framework for query.  migrating old test/orm/mapper.py tests over to new query.py

lib/sqlalchemy/orm/query.py
test/orm/alltests.py
test/orm/mapper.py
test/orm/query.py [new file with mode: 0644]

index f2e87b7deaa7ffc125f283372bcfb6d5ad8bdb60..690d8b2f416b5016a8e6ab6c41eaa4f6e7dafbbb 100644 (file)
@@ -535,8 +535,13 @@ class Query(object):
         else:
             [keys,p] = self._locate_prop(prop, start=start)
         clause = self._from_obj[-1]
-
-        currenttables = sql_util.TableFinder(self._from_obj, include_aliases=True)
+        
+        currenttables = [clause]
+        class FindJoinedTables(sql.NoColumnVisitor):
+            def visit_join(self, join):
+                currenttables.append(join.left)
+                currenttables.append(join.right)
+        FindJoinedTables().traverse(clause)
             
         mapper = start
         for key in keys:
@@ -823,15 +828,19 @@ class Query(object):
         new._distinct = True
         return new
 
-    def list(self):
+    def all(self):
         """Return the results represented by this ``Query`` as a list.
 
         This results in an execution of the underlying query.
         """
-
         return list(self)
+        
+    def list(self):
+        """deprecated.  use all()"""
 
-    def scalar(self):
+        return list(self)
+    
+    def first(self):
         """Return the first result of this ``Query``.
 
         This results in an execution of the underlying query.
@@ -840,6 +849,24 @@ class Query(object):
             return self[0]
         else:
             return self._col_aggregate(self._col, self._func)
+
+    def scalar(self):
+        """deprecated.  use first()"""
+        return self.first()
+
+    def one(self):
+        """Return the first result of this ``Query``, raising an exception if more than one row exists.
+
+        This results in an execution of the underlying query.
+        """
+        ret = list(self[0:2])
+        
+        if len(ret) == 1:
+            return ret[0]
+        elif len(ret) == 0:
+            raise exceptions.InvalidRequestError('No rows returned for one()')
+        else:
+            raise exceptions.InvalidRequestError('Multiple rows returned for one()')
     
     def __iter__(self):
         return iter(self.select_whereclause())
index bdcca39979bb90948647da6148467128b2010802..395b402bdb2c31f9f46ab67e128cce5ef68c8d5c 100644 (file)
@@ -6,6 +6,7 @@ import inheritance.alltests as inheritance
 def suite():
     modules_to_test = (
        'orm.attributes',
+           'orm.query',
         'orm.mapper',
         'orm.generative',
         'orm.lazytest1',
index 34d0c81e473856660f067d11a8e4ea345732ee54..558fa62809a05c4c871ccafcd068f30919bf6be7 100644 (file)
@@ -22,36 +22,6 @@ class MapperSuperTest(AssertMixin):
         pass
     
 class MapperTest(MapperSuperTest):
-    # TODO: MapperTest has grown much larger than it originally was and needs
-    # to be broken up among various functions, including querying, session operations,
-    # mapper configurational issues
-    def testget(self):
-        s = create_session()
-        mapper(User, users)
-        self.assert_(s.get(User, 19) is None)
-        u = s.get(User, 7)
-        u2 = s.get(User, 7)
-        self.assert_(u is u2)
-        s.clear()
-        u2 = s.get(User, 7)
-        self.assert_(u is not u2)
-
-    def testunicodeget(self):
-        """test that Query.get properly sets up the type for the bind parameter.  using unicode would normally fail 
-        on postgres, mysql and oracle unless it is converted to an encoded string"""
-        metadata = BoundMetaData(db)
-        table = Table('foo', metadata, 
-            Column('id', Unicode(10), primary_key=True),
-            Column('data', Unicode(40)))
-        try:
-            table.create()
-            class LocalFoo(object):pass
-            mapper(LocalFoo, table)
-            crit = 'petit voix m\xe2\x80\x99a '.decode('utf-8')
-            print repr(crit)
-            create_session().query(LocalFoo).get(crit)
-        finally:
-            table.drop()
 
     def testpropconflict(self):
         """test that a backref created against an existing mapper with a property name
@@ -259,13 +229,6 @@ class MapperTest(MapperSuperTest):
         }).compile()
         self.assert_(User.addresses.property is m.props['addresses'])
         
-    def testquery(self):
-        """test a basic Query.select() operation."""
-        mapper(User, users)
-        l = create_session().query(User).select()
-        self.assert_result(l, User, *user_result)
-        l = create_session().query(User).select(users.c.user_name.endswith('ed'))
-        self.assert_result(l, User, *user_result[1:3])
 
     def testrecursiveselectby(self):
         """test that no endless loop occurs when traversing for select_by"""
@@ -300,118 +263,6 @@ class MapperTest(MapperSuperTest):
         l = q.select()
         self.assert_result(l, User, *result)
 
-    def testwithparent(self):
-        """test the with_parent()) method and one-to-many relationships"""
-        
-        m = mapper(User, users, properties={
-            'orders':relation(mapper(Order, orders, properties={
-                'items':relation(mapper(Item, orderitems))
-            }))
-        })
-
-        sess = create_session()
-        q = sess.query(m)
-        u1 = q.get_by(user_name='jack')
-
-        # test auto-lookup of property
-        o = sess.query(Order).with_parent(u1).list()
-        self.assert_result(o, Order, *user_all_result[0]['orders'][1])
-
-        # test with explicit property
-        o = sess.query(Order).with_parent(u1, property='orders').list()
-        self.assert_result(o, Order, *user_all_result[0]['orders'][1])
-
-        # test static method
-        o = Query.query_from_parent(u1, property='orders', session=sess).list()
-        self.assert_result(o, Order, *user_all_result[0]['orders'][1])
-
-        # test generative criterion
-        o = sess.query(Order).with_parent(u1).select_by(orders.c.order_id>2)
-        self.assert_result(o, Order, *user_all_result[0]['orders'][1][1:])
-
-        try:
-            q = sess.query(Item).with_parent(u1)
-            assert False
-        except exceptions.InvalidRequestError, e:
-            assert str(e) == "Could not locate a property which relates instances of class 'Item' to instances of class 'User'"
-            
-    def testwithparentm2m(self):
-        """test the with_parent() method and many-to-many relationships"""
-        
-        m = mapper(Item, orderitems, properties = {
-                'keywords' : relation(mapper(Keyword, keywords), itemkeywords)
-        })
-        sess = create_session()
-        i1 = sess.query(Item).get_by(item_id=2)
-        k = sess.query(Keyword).with_parent(i1)
-        self.assert_result(k, Keyword, *item_keyword_result[1]['keywords'][1])
-
-        
-    def testautojoin(self):
-        """test functions derived from Query's _join_to function."""
-        
-        m = mapper(User, users, properties={
-            'orders':relation(mapper(Order, orders, properties={
-                'items':relation(mapper(Item, orderitems))
-            }))
-        })
-
-        sess = create_session()
-        q = sess.query(m)
-
-        l = q.filter(orderitems.c.item_name=='item 4').join(['orders', 'items']).list()
-        self.assert_result(l, User, user_result[0])
-        
-        l = q.select_by(item_name='item 4')
-        self.assert_result(l, User, user_result[0])
-
-        l = q.filter(orderitems.c.item_name=='item 4').join('item_name').list()
-        self.assert_result(l, User, user_result[0])
-
-        l = q.filter(orderitems.c.item_name=='item 4').join('items').list()
-        self.assert_result(l, User, user_result[0])
-
-        # test comparing to an object instance
-        item = sess.query(Item).get_by(item_name='item 4')
-
-        l = sess.query(Order).select_by(items=item)
-        self.assert_result(l, Order, user_all_result[0]['orders'][1][1])
-
-        l = q.select_by(items=item)
-        self.assert_result(l, User, user_result[0])
-        
-        # TODO: this works differently from:
-        #q = sess.query(User).join(['orders', 'items']).select_by(order_id=3)
-        # because select_by() doesnt respect query._joinpoint, whereas filter_by does
-        q = sess.query(User).join(['orders', 'items']).filter_by(order_id=3).list()
-        self.assert_result(l, User, user_result[0])
-        
-        try:
-            # this should raise AttributeError
-            l = q.select_by(items=5)
-            assert False
-        except AttributeError:
-            assert True
-        
-    def testautojoinm2m(self):
-        """test functions derived from Query's _join_to function."""
-        
-        m = mapper(Order, orders, properties = {
-            'items' : relation(mapper(Item, orderitems, properties = {
-                'keywords' : relation(mapper(Keyword, keywords), itemkeywords)
-            }))
-        })
-        
-        sess = create_session()
-        q = sess.query(m)
-
-        l = q.filter(keywords.c.name=='square').join(['items', 'keywords']).list()
-        self.assert_result(l, Order, order_result[1])
-
-        # test comparing to an object instance
-        item = sess.query(Item).selectfirst()
-        l = sess.query(Item).select_by(keywords=item.keywords[0])
-        assert item == l[0]
         
     def testcustomjoin(self):
         """test that the from_obj parameter to query.select() can be used
diff --git a/test/orm/query.py b/test/orm/query.py
new file mode 100644 (file)
index 0000000..8d3f5e6
--- /dev/null
@@ -0,0 +1,262 @@
+from sqlalchemy import *
+from sqlalchemy.orm import *
+import testbase
+
+class QueryTest(testbase.ORMTest):
+    keep_mappers = True
+    keep_data = True
+    
+    def setUpAll(self):
+        super(QueryTest, self).setUpAll()
+        self.install_fixture_data()
+        self.setup_mappers()
+        
+    def define_tables(self, metadata):
+        global users, orders, addresses, items, order_items, item_keywords, keywords
+        
+        users = Table('users', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('name', String(30), nullable=False))
+
+        orders = Table('orders', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('user_id', None, ForeignKey('users.id')),
+            Column('address_id', None, ForeignKey('addresses.id')),
+            Column('description', String(30)),
+            Column('isopen', Integer)
+            )
+
+        addresses = Table('addresses', metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('user_id', None, ForeignKey('users.id')),
+            Column('email_address', String(50), nullable=False))
+            
+        items = Table('items', metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('description', String(30), nullable=False)
+            )
+
+        order_items = Table('order_items', metadata,
+            Column('item_id', None, ForeignKey('items.id')),
+            Column('order_id', None, ForeignKey('orders.id')))
+
+        item_keywords = Table('item_keywords', metadata, 
+            Column('item_id', None, ForeignKey('items.id')),
+            Column('keyword_id', None, ForeignKey('keywords.id')))
+
+        keywords = Table('keywords', metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('name', String(30), nullable=False)
+            )
+        
+    def install_fixture_data(self):
+        users.insert().execute(
+            dict(id = 7, name = 'jack'),
+            dict(id = 8, name = 'ed'),
+            dict(id = 9, name = 'fred'),
+            dict(id = 10, name = 'chuck'),
+            
+        )
+        addresses.insert().execute(
+            dict(id = 1, user_id = 7, email_address = "jack@bean.com"),
+            dict(id = 2, user_id = 8, email_address = "ed@wood.com"),
+            dict(id = 3, user_id = 8, email_address = "ed@bettyboop.com"),
+            dict(id = 4, user_id = 8, email_address = "ed@lala.com"),
+            dict(id = 5, user_id = 9, email_address = "fred@fred.com"),
+        )
+        orders.insert().execute(
+            dict(id = 1, user_id = 7, description = 'order 1', isopen=0, address_id=1),
+            dict(id = 2, user_id = 9, description = 'order 2', isopen=0, address_id=4),
+            dict(id = 3, user_id = 7, description = 'order 3', isopen=1, address_id=1),
+            dict(id = 4, user_id = 9, description = 'order 4', isopen=1, address_id=4),
+            dict(id = 5, user_id = 7, description = 'order 5', isopen=0, address_id=1)
+        )
+        items.insert().execute(
+            dict(id=1, description='item 1'),
+            dict(id=2, description='item 2'),
+            dict(id=3, description='item 3'),
+            dict(id=4, description='item 4'),
+            dict(id=5, description='item 5'),
+        )
+        order_items.insert().execute(
+            dict(item_id=1, order_id=1),
+            dict(item_id=2, order_id=1),
+            dict(item_id=3, order_id=1),
+
+            dict(item_id=1, order_id=2),
+            dict(item_id=2, order_id=2),
+            dict(item_id=3, order_id=2),
+            dict(item_id=2, order_id=2),
+
+            dict(item_id=3, order_id=3),
+            dict(item_id=4, order_id=3),
+            dict(item_id=5, order_id=3),
+            
+            dict(item_id=5, order_id=4),
+            dict(item_id=1, order_id=4),
+            
+            dict(item_id=5, order_id=5),
+        )
+        keywords.insert().execute(
+            dict(id=1, name='blue'),
+            dict(id=2, name='red'),
+            dict(id=3, name='green'),
+            dict(id=4, name='big'),
+            dict(id=5, name='small'),
+            dict(id=6, name='round'),
+            dict(id=7, name='square')
+        )
+
+        # this many-to-many table has the keywords inserted
+        # in primary key order, to appease the unit tests.
+        # this is because postgres, oracle, and sqlite all support 
+        # true insert-order row id, but of course our pal MySQL does not,
+        # so the best it can do is order by, well something, so there you go.
+        item_keywords.insert().execute(
+            dict(keyword_id=2, item_id=1),
+            dict(keyword_id=2, item_id=2),
+            dict(keyword_id=4, item_id=1),
+            dict(keyword_id=6, item_id=1),
+            dict(keyword_id=5, item_id=2),
+            dict(keyword_id=3, item_id=3),
+            dict(keyword_id=4, item_id=3),
+            dict(keyword_id=7, item_id=2),
+            dict(keyword_id=6, item_id=3)
+        )
+        
+    def setup_mappers(self):
+        global User, Order, Item, Keyword, Address, Base
+        
+        class Base(object):
+            def __init__(self, **kwargs):
+                for k in kwargs:
+                    setattr(self, k, kwargs[k])
+            def __eq__(self, other):
+                for attr in dir(self):
+                    if attr[0] == '_':
+                        continue
+                    value = getattr(self, attr)
+                    if isinstance(value, list):
+                        for (us, them) in zip(value, getattr(other, attr)):
+                            if us != them:
+                                return False
+                        else:
+                            return True
+                    else:
+                        if value is not None:
+                            return value == getattr(other, attr)
+                    
+        class User(Base):pass
+        class Order(Base):pass
+        class Item(Base):pass
+        class Keyword(Base):pass
+        class Address(Base):pass
+
+        mapper(User, users, properties={
+            'orders':relation(Order, backref='user'), # o2m, m2o
+        })
+        mapper(Address, addresses)
+        mapper(Order, orders, properties={
+            'items':relation(Item, secondary=order_items),  #m2m
+            'address':relation(Address),  # m2o
+        })
+        mapper(Item, items, properties={
+            'keywords':relation(Keyword, secondary=item_keywords) #m2m
+        })
+        mapper(Keyword, keywords)
+
+class GetTest(QueryTest):
+    def test_get(self):
+        s = create_session()
+        assert s.query(User).get(19) is None
+        u = s.query(User).get(7)
+        u2 = s.query(User).get(7)
+        assert u is u2
+        s.clear()
+        u2 = s.query(User).get(7)
+        assert u is not u2
+
+    def test_unicode(self):
+        """test that Query.get properly sets up the type for the bind parameter.  using unicode would normally fail 
+        on postgres, mysql and oracle unless it is converted to an encoded string"""
+        
+        table = Table('unicode_data', users.metadata, 
+            Column('id', Unicode(40), primary_key=True),
+            Column('data', Unicode(40)))
+        table.create()
+        ustring = 'petit voix m\xe2\x80\x99a '.decode('utf-8')
+        table.insert().execute(id=ustring, data=ustring)
+        class LocalFoo(Base):pass
+        mapper(LocalFoo, table)
+        assert create_session().query(LocalFoo).get(ustring) == LocalFoo(id=ustring, data=ustring)
+        
+class FilterTest(QueryTest):
+    def test_basic(self):
+        assert create_session().query(User).all() == [User(id=7), User(id=8), User(id=9),User(id=10)]
+
+    def test_onefilter(self):
+        assert create_session().query(User).filter(users.c.name.endswith('ed')).all() == [User(id=8), User(id=9)]
+
+class ParentTest(QueryTest):
+    def test_o2m(self):
+        sess = create_session()
+        q = sess.query(User)
+        
+        u1 = q.filter_by(name='jack').one()
+
+        # test auto-lookup of property
+        o = sess.query(Order).with_parent(u1).all()
+        assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o
+
+        # test with explicit property
+        o = sess.query(Order).with_parent(u1, property='orders').all()
+        assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o
+
+        # test static method
+        o = Query.query_from_parent(u1, property='orders', session=sess).all()
+        assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o
+
+        # test generative criterion
+        o = sess.query(Order).with_parent(u1).filter(orders.c.id>2).all()
+        assert [Order(description="order 3"), Order(description="order 5")] == o
+
+    def test_noparent(self):
+        sess = create_session()
+        q = sess.query(User)
+        
+        u1 = q.filter_by(name='jack').one()
+
+        try:
+            q = sess.query(Item).with_parent(u1)
+            assert False
+        except exceptions.InvalidRequestError, e:
+            assert str(e) == "Could not locate a property which relates instances of class 'Item' to instances of class 'User'"
+
+    def test_m2m(self):
+        sess = create_session()
+        i1 = sess.query(Item).filter_by(id=2).one()
+        k = sess.query(Keyword).with_parent(i1).all()
+        assert [Keyword(name='red'), Keyword(name='small'), Keyword(name='square')] == k
+        
+    
+class JoinTest(QueryTest):
+    def test_overlapping_paths(self):
+        result = create_session().query(User).join(['orders', 'items']).filter_by(id=3).join(['orders','address']).filter_by(id=1).all()
+        assert [User(id=7, name='jack')] == result
+    
+    def test_overlap_with_aliases(self):
+        oalias = orders.alias('oalias')
+
+        result = create_session().query(User).select_from(users.join(oalias)).filter(oalias.c.description.in_("order 1", "order 2", "order 3")).join(['orders', 'items']).all()
+        assert [User(id=7, name='jack'), User(id=9, name='fred')] == result
+        
+        result = create_session().query(User).select_from(users.join(oalias)).filter(oalias.c.description.in_("order 1", "order 2", "order 3")).join(['orders', 'items']).filter_by(id=4).all()
+        assert [User(id=7, name='jack')] == result
+
+
+
+
+if __name__ == '__main__':
+    testbase.main()
+
+