]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
adding tests, all the features present in joined eager loading
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 23 Mar 2010 16:11:20 +0000 (12:11 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 23 Mar 2010 16:11:20 +0000 (12:11 -0400)
lib/sqlalchemy/orm/strategies.py
test/orm/test_eager_relations.py
test/orm/test_subquery_relations.py

index 46cfdbe61c27168af11012a705a2a2c1c4b21f56..259e0f10c13c154f8f717b0e849f1036260a839a 100644 (file)
@@ -632,6 +632,10 @@ class LoadLazyAttribute(object):
                 return None
 
 class SubqueryLoader(AbstractRelationshipLoader):
+    def init(self):
+        super(SubqueryLoader, self).init()
+        self.join_depth = self.parent_property.join_depth
+    
     def init_class_attribute(self, mapper):
         self.parent_property.\
                 _get_strategy(LazyLoader).\
@@ -643,6 +647,27 @@ class SubqueryLoader(AbstractRelationshipLoader):
 
         if not context.query._enable_eagerloads:
             return
+
+        path = path + (self.key, )
+
+        # build up a path indicating the path from the leftmost
+        # entity to the thing we're subquery loading.
+        subq_path = context.attributes.get(('subquery_path', None), ())
+
+        subq_path = subq_path + path
+
+        reduced_path = interfaces._reduce_path(subq_path)
+
+        # check for join_depth or basic recursion,
+        # if the current path was not explicitly stated as 
+        # a desired "loaderstrategy" (i.e. via query.options())
+        if ("loaderstrategy", reduced_path) not in context.attributes:
+            if self.join_depth:
+                if len(path) / 2 > self.join_depth:
+                    return
+            else:
+                if self.mapper.base_mapper in reduced_path:
+                    return
         
         # the leftmost query we'll be joining from.
         # in the case of an end-user query with eager or subq
@@ -655,15 +680,9 @@ class SubqueryLoader(AbstractRelationshipLoader):
             
         orig_query = context.attributes[("orig_query", SubqueryLoader)]
 
-        # build up a path indicating the path from the leftmost
-        # entity to the thing we're subquery loading.
-        subq_path = context.attributes.get(('subquery_path', None), ())
-        
-        path = path + (self.key, )
         
         local_cols, remote_cols = self._local_remote_columns(self.parent_property)
         
-        subq_path = subq_path + path
         leftmost_mapper, leftmost_prop = \
                             subq_path[0], subq_path[0].get_property(subq_path[1])
         leftmost_cols, remote_cols = self._local_remote_columns(leftmost_prop)
@@ -688,7 +707,8 @@ class SubqueryLoader(AbstractRelationshipLoader):
         q._attributes = {}
         q._attributes[("orig_query", SubqueryLoader)] = orig_query
         q._set_entities(leftmost_attr)
-        q._order_by = None
+        if q._limit is None and q._offset is None:
+            q._order_by = None
         
         q._attributes[('subquery_path', None)] = subq_path
         
@@ -750,15 +770,31 @@ class SubqueryLoader(AbstractRelationshipLoader):
         
         if adapter:
             local_cols = [adapter.columns[c] for c in local_cols]
-
-        def execute(state, dict_, row):
-            collection = collections.get(
-                tuple([row[col] for col in local_cols]), 
-                ()
-            )
-            state.get_impl(self.key).\
-                    set_committed_value(state, dict_, collection)
-                
+        
+        if self.uselist:
+            def execute(state, dict_, row):
+                collection = collections.get(
+                    tuple([row[col] for col in local_cols]), 
+                    ()
+                )
+                state.get_impl(self.key).\
+                        set_committed_value(state, dict_, collection)
+        else:
+            def execute(state, dict_, row):
+                collection = collections.get(
+                    tuple([row[col] for col in local_cols]), 
+                    (None,)
+                )
+                if len(collection) > 1:
+                    util.warn(
+                        "Multiple rows returned with "
+                        "uselist=False for eagerly-loaded attribute '%s' "
+                        % self)
+                    
+                scalar = collection[0]
+                state.get_impl(self.key).\
+                        set_committed_value(state, dict_, scalar)
+            
         return execute, None
 
 class EagerLoader(AbstractRelationshipLoader):
index 0f635e905023e04a9a0a29676e43de27fc0d4028..e06aa6ff1d95f15f65c03de8e1dde4248c072fce 100644 (file)
@@ -332,7 +332,8 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
         mapper(Address, addresses)
         mapper(User, users, properties = dict(
             addresses = relationship(Address, lazy=False,
-                                 backref=sa.orm.backref('user', lazy=False), order_by=Address.id)
+                                 backref=sa.orm.backref('user', lazy=False),
+                                            order_by=Address.id)
         ))
         is_(sa.orm.class_mapper(User).get_property('addresses').lazy, False)
         is_(sa.orm.class_mapper(Address).get_property('user').lazy, False)
@@ -342,7 +343,8 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
 
     @testing.resolve_artifact_names
     def test_double(self):
-        """Eager loading with two relationships simultaneously, from the same table, using aliases."""
+        """Eager loading with two relationships simultaneously, 
+            from the same table, using aliases."""
 
         openorders = sa.alias(orders, 'openorders')
         closedorders = sa.alias(orders, 'closedorders')
@@ -395,7 +397,8 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
 
     @testing.resolve_artifact_names
     def test_double_same_mappers(self):
-        """Eager loading with two relationships simulatneously, from the same table, using aliases."""
+        """Eager loading with two relationships simulatneously, 
+        from the same table, using aliases."""
 
         mapper(Address, addresses)
         mapper(Order, orders, properties={
@@ -461,7 +464,8 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
 
     @testing.resolve_artifact_names
     def test_no_false_hits(self):
-        """Eager loaders don't interpret main table columns as part of their eager load."""
+        """Eager loaders don't interpret main table columns as 
+        part of their eager load."""
 
         mapper(User, users, properties={
             'addresses':relationship(Address, lazy=False),
@@ -476,7 +480,8 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
         # eager loaders have aliases which should not hit on those columns,
         # they should be required to locate only their aliased/fully table
         # qualified column name.
-        noeagers = create_session().query(User).from_statement("select * from users").all()
+        noeagers = create_session().query(User).\
+                        from_statement("select * from users").all()
         assert 'orders' not in noeagers[0].__dict__
         assert 'addresses' not in noeagers[0].__dict__
 
@@ -487,7 +492,8 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
 
         mapper(Item, items)
         mapper(Order, orders, properties={
-            'items':relationship(Item, secondary=order_items, lazy=False, order_by=items.c.id)
+            'items':relationship(Item, secondary=order_items, lazy=False,
+                order_by=items.c.id)
         })
         mapper(User, users, properties={
             'addresses':relationship(mapper(Address, addresses), lazy=False, order_by=addresses.c.id),
@@ -530,7 +536,9 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
 
         sess = create_session()
         q = sess.query(Item)
-        l = q.filter((Item.description=='item 2') | (Item.description=='item 5') | (Item.description=='item 3')).\
+        l = q.filter((Item.description=='item 2') | 
+                        (Item.description=='item 5') | 
+                        (Item.description=='item 3')).\
             order_by(Item.id).limit(2).all()
 
         eq_(self.static.item_keyword_result[1:3], l)
@@ -538,8 +546,10 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
     @testing.fails_on('maxdb', 'FIXME: unknown')
     @testing.resolve_artifact_names
     def test_limit_3(self):
-        """test that the ORDER BY is propagated from the inner select to the outer select, when using the
-        'wrapped' select statement resulting from the combination of eager loading and limit/offset clauses."""
+        """test that the ORDER BY is propagated from the inner 
+        select to the outer select, when using the
+        'wrapped' select statement resulting from the combination of 
+        eager loading and limit/offset clauses."""
 
         mapper(Item, items)
         mapper(Order, orders, properties = dict(
@@ -685,7 +695,8 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
     @testing.resolve_artifact_names
     def test_one_to_many_scalar(self):
         mapper(User, users, properties = dict(
-            address = relationship(mapper(Address, addresses), lazy=False, uselist=False)
+            address = relationship(mapper(Address, addresses), 
+                                    lazy=False, uselist=False)
         ))
         q = create_session().query(User)
 
@@ -765,13 +776,16 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
     @testing.resolve_artifact_names
     def test_double_with_aggregate(self):
         max_orders_by_user = sa.select([sa.func.max(orders.c.id).label('order_id')],
-                                       group_by=[orders.c.user_id]).alias('max_orders_by_user')
+                                       group_by=[orders.c.user_id]
+                                     ).alias('max_orders_by_user')
 
-        max_orders = orders.select(orders.c.id==max_orders_by_user.c.order_id).alias('max_orders')
+        max_orders = orders.select(orders.c.id==max_orders_by_user.c.order_id).\
+                                alias('max_orders')
 
         mapper(Order, orders)
         mapper(User, users, properties={
-               'orders':relationship(Order, backref='user', lazy=False, order_by=orders.c.id),
+               'orders':relationship(Order, backref='user', lazy=False,
+                                            order_by=orders.c.id),
                'max_order':relationship(
                                 mapper(Order, max_orders, non_primary=True), 
                                 lazy=False, uselist=False)
@@ -798,14 +812,16 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
 
     @testing.resolve_artifact_names 
     def test_uselist_false_warning(self):
-        """test that multiple rows received by a uselist=False raises a warning."""
+        """test that multiple rows received by a 
+        uselist=False raises a warning."""
         
         mapper(User, users, properties={
             'order':relationship(Order, uselist=False)
         })
         mapper(Order, orders)
         s = create_session()
-        assert_raises(sa.exc.SAWarning, s.query(User).options(eagerload(User.order)).all)
+        assert_raises(sa.exc.SAWarning,
+                s.query(User).options(eagerload(User.order)).all)
         
     @testing.resolve_artifact_names
     def test_wide(self):
index efa589f3c65522ddcdbee52644d43d47563347e8..6385bbb5325e1d375f1cba809966ef21bfaf8e0e 100644 (file)
@@ -211,29 +211,8 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
                 options.append(callables[i](User.orders, Order.items))
             if k in callables:
                 options.append(callables[k](User.orders, Order.items, Item.keywords))
-                
-            sess = create_session()
-            def go():
-                eq_(
-                    sess.query(User).options(*options).order_by(User.id).all(),
-                    self.static.user_item_keyword_result
-                )
-            self.assert_sql_count(testing.db, go, count)
-
-            sess = create_session()
-            eq_(
-                sess.query(User).filter(User.name=='fred').
-                        options(*options).order_by(User.id).all(),
-                self.static.user_item_keyword_result[2:3]
-            )
 
-            sess = create_session()
-            eq_(
-                sess.query(User).join(User.orders).
-                        filter(Order.id==3).\
-                        options(*options).order_by(User.id).all(),
-                self.static.user_item_keyword_result[0:1]
-            )
+            self._do_query_tests(options, count)
 
     @testing.resolve_artifact_names
     def _do_mapper_test(self, configs):
@@ -241,7 +220,6 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
             'lazyload':'select',
             'eagerload':'joined',
             'subqueryload':'subquery',
-            
         }
 
         for o, i, k, count in configs:
@@ -259,34 +237,281 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
                                             order_by=keywords.c.id)
             })
             mapper(Keyword, keywords)
-
-            sess = create_session()
-            def go():
-                eq_(
-                    sess.query(User).order_by(User.id).all(),
-                    self.static.user_item_keyword_result
-                )
+            
             try:
-                self.assert_sql_count(testing.db, go, count)
-
-                eq_(
-                    sess.query(User).filter(User.name=='fred').
-                            order_by(User.id).all(),
-                    self.static.user_item_keyword_result[2:3]
-                )
-
-                sess = create_session()
-                eq_(
-                    sess.query(User).join(User.orders).
-                            filter(Order.id==3).\
-                            order_by(User.id).all(),
-                    self.static.user_item_keyword_result[0:1]
-                )
-
+                self._do_query_tests([], count)
             finally:
                 clear_mappers()
+    
+    @testing.resolve_artifact_names
+    def _do_query_tests(self, opts, count):
+        sess = create_session()
+        def go():
+            eq_(
+                sess.query(User).options(*opts).order_by(User.id).all(),
+                self.static.user_item_keyword_result
+            )
+        self.assert_sql_count(testing.db, go, count)
+
+        eq_(
+            sess.query(User).options(*opts).filter(User.name=='fred').
+                    order_by(User.id).all(),
+            self.static.user_item_keyword_result[2:3]
+        )
+
+        sess = create_session()
+        eq_(
+            sess.query(User).options(*opts).join(User.orders).
+                    filter(Order.id==3).\
+                    order_by(User.id).all(),
+            self.static.user_item_keyword_result[0:1]
+        )
         
-        
+    
+    @testing.resolve_artifact_names
+    def test_cyclical(self):
+        """A circular eager relationship breaks the cycle with a lazy loader"""
+
+        mapper(Address, addresses)
+        mapper(User, users, properties = dict(
+            addresses = relationship(Address, lazy='subquery',
+                                 backref=sa.orm.backref('user', lazy='subquery'),
+                                            order_by=Address.id)
+        ))
+        is_(sa.orm.class_mapper(User).get_property('addresses').lazy, 'subquery')
+        is_(sa.orm.class_mapper(Address).get_property('user').lazy, 'subquery')
+
+        sess = create_session()
+        eq_(self.static.user_address_result, sess.query(User).order_by(User.id).all())
+
+    @testing.resolve_artifact_names
+    def test_double(self):
+        """Eager loading with two relationships simultaneously, 
+            from the same table, using aliases."""
+
+        openorders = sa.alias(orders, 'openorders')
+        closedorders = sa.alias(orders, 'closedorders')
+
+        mapper(Address, addresses)
+        mapper(Order, orders)
+
+        open_mapper = mapper(Order, openorders, non_primary=True)
+        closed_mapper = mapper(Order, closedorders, non_primary=True)
+
+        mapper(User, users, properties = dict(
+            addresses = relationship(Address, lazy='subquery',
+                                        order_by=addresses.c.id),
+            open_orders = relationship(
+                open_mapper,
+                primaryjoin=sa.and_(openorders.c.isopen == 1,
+                                 users.c.id==openorders.c.user_id),
+                lazy='subquery', order_by=openorders.c.id),
+            closed_orders = relationship(
+                closed_mapper,
+                primaryjoin=sa.and_(closedorders.c.isopen == 0,
+                                 users.c.id==closedorders.c.user_id),
+                lazy='subquery', order_by=closedorders.c.id)))
+
+        q = create_session().query(User).order_by(User.id)
+
+        def go():
+            eq_([
+                User(
+                    id=7,
+                    addresses=[Address(id=1)],
+                    open_orders = [Order(id=3)],
+                    closed_orders = [Order(id=1), Order(id=5)]
+                ),
+                User(
+                    id=8,
+                    addresses=[Address(id=2), Address(id=3), Address(id=4)],
+                    open_orders = [],
+                    closed_orders = []
+                ),
+                User(
+                    id=9,
+                    addresses=[Address(id=5)],
+                    open_orders = [Order(id=4)],
+                    closed_orders = [Order(id=2)]
+                ),
+                User(id=10)
+
+            ], q.all())
+        self.assert_sql_count(testing.db, go, 4)
+
+    @testing.resolve_artifact_names
+    def test_double_same_mappers(self):
+        """Eager loading with two relationships simulatneously, 
+        from the same table, using aliases."""
+
+        mapper(Address, addresses)
+        mapper(Order, orders, properties={
+            'items': relationship(Item, secondary=order_items, lazy='subquery',
+                              order_by=items.c.id)})
+        mapper(Item, items)
+        mapper(User, users, properties=dict(
+            addresses=relationship(Address, lazy='subquery', order_by=addresses.c.id),
+            open_orders=relationship(
+                Order,
+                primaryjoin=sa.and_(orders.c.isopen == 1,
+                                 users.c.id==orders.c.user_id),
+                lazy='subquery', order_by=orders.c.id),
+            closed_orders=relationship(
+                Order,
+                primaryjoin=sa.and_(orders.c.isopen == 0,
+                                 users.c.id==orders.c.user_id),
+                lazy='subquery', order_by=orders.c.id)))
+        q = create_session().query(User).order_by(User.id)
+
+        def go():
+            eq_([
+                User(id=7,
+                     addresses=[
+                       Address(id=1)],
+                     open_orders=[Order(id=3,
+                                        items=[
+                                          Item(id=3),
+                                          Item(id=4),
+                                          Item(id=5)])],
+                     closed_orders=[Order(id=1,
+                                          items=[
+                                            Item(id=1),
+                                            Item(id=2),
+                                            Item(id=3)]),
+                                    Order(id=5,
+                                          items=[
+                                            Item(id=5)])]),
+                User(id=8,
+                     addresses=[
+                       Address(id=2),
+                       Address(id=3),
+                       Address(id=4)],
+                     open_orders = [],
+                     closed_orders = []),
+                User(id=9,
+                     addresses=[
+                       Address(id=5)],
+                     open_orders=[
+                       Order(id=4,
+                             items=[
+                               Item(id=1),
+                               Item(id=5)])],
+                     closed_orders=[
+                       Order(id=2,
+                             items=[
+                               Item(id=1),
+                               Item(id=2),
+                               Item(id=3)])]),
+                User(id=10)
+            ], q.all())
+        self.assert_sql_count(testing.db, go, 6)
+
+    @testing.fails_on('maxdb', 'FIXME: unknown')
+    @testing.resolve_artifact_names
+    def test_limit(self):
+        """Limit operations combined with lazy-load relationships."""
+
+        mapper(Item, items)
+        mapper(Order, orders, properties={
+            'items':relationship(Item, secondary=order_items, lazy='subquery',
+                order_by=items.c.id)
+        })
+        mapper(User, users, properties={
+            'addresses':relationship(mapper(Address, addresses), 
+                            lazy='subquery', 
+                            order_by=addresses.c.id),
+            'orders':relationship(Order, lazy=True, order_by=orders.c.id)
+        })
+
+        sess = create_session()
+        q = sess.query(User)
+
+        l = q.order_by(User.id).limit(2).offset(1).all()
+        eq_(self.static.user_all_result[1:3], l)
+
+        sess = create_session()
+        l = q.order_by(sa.desc(User.id)).limit(2).offset(2).all()
+        eq_(list(reversed(self.static.user_all_result[0:2])), l)
+
+    @testing.resolve_artifact_names
+    def test_one_to_many_scalar(self):
+        mapper(User, users, properties = dict(
+            address = relationship(mapper(Address, addresses), 
+                                    lazy='subquery', uselist=False)
+        ))
+        q = create_session().query(User)
+
+        def go():
+            l = q.filter(users.c.id == 7).all()
+            eq_([User(id=7, address=Address(id=1))], l)
+        self.assert_sql_count(testing.db, go, 2)
+
+    @testing.fails_on('maxdb', 'FIXME: unknown')
+    @testing.resolve_artifact_names
+    def test_many_to_one(self):
+        mapper(Address, addresses, properties = dict(
+            user = relationship(mapper(User, users), lazy='subquery')
+        ))
+        sess = create_session()
+        q = sess.query(Address)
+
+        def go():
+            a = q.filter(addresses.c.id==1).one()
+            is_not_(a.user, None)
+            u1 = sess.query(User).get(7)
+            is_(a.user, u1)
+        self.assert_sql_count(testing.db, go, 2)
+
+    @testing.resolve_artifact_names
+    def test_double_with_aggregate(self):
+        max_orders_by_user = sa.select([sa.func.max(orders.c.id).label('order_id')],
+                                       group_by=[orders.c.user_id]
+                                     ).alias('max_orders_by_user')
+
+        max_orders = orders.select(orders.c.id==max_orders_by_user.c.order_id).\
+                                alias('max_orders')
+
+        mapper(Order, orders)
+        mapper(User, users, properties={
+               'orders':relationship(Order, backref='user', lazy='subquery',
+                                            order_by=orders.c.id),
+               'max_order':relationship(
+                                mapper(Order, max_orders, non_primary=True), 
+                                lazy='subquery', uselist=False)
+               })
+
+        q = create_session().query(User)
+
+        def go():
+            eq_([
+                User(id=7, orders=[
+                        Order(id=1),
+                        Order(id=3),
+                        Order(id=5),
+                    ],
+                    max_order=Order(id=5)
+                ),
+                User(id=8, orders=[]),
+                User(id=9, orders=[Order(id=2),Order(id=4)],
+                    max_order=Order(id=4)
+                ),
+                User(id=10),
+            ], q.order_by(User.id).all())
+        self.assert_sql_count(testing.db, go, 3)
+
+    @testing.resolve_artifact_names 
+    def test_uselist_false_warning(self):
+        """test that multiple rows received by a 
+        uselist=False raises a warning."""
+
+        mapper(User, users, properties={
+            'order':relationship(Order, uselist=False)
+        })
+        mapper(Order, orders)
+        s = create_session()
+        assert_raises(sa.exc.SAWarning,
+                s.query(User).options(subqueryload(User.order)).all)
+
     # TODO: all the tests in test_eager_relations
     
     # TODO: ensure state stuff works out OK, existing objects/collections