]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added a full exercising test for all of #946, #947, #948, #949
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 27 Dec 2008 19:35:12 +0000 (19:35 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 27 Dec 2008 19:35:12 +0000 (19:35 +0000)
test/orm/eager_relations.py

index 0083d3848be4a84c4d068c9ad6d1fef7c18d82f8..ba9b7e02505cdbf8cc2875368c29aa0e90c4d536 100644 (file)
@@ -3,11 +3,12 @@
 import testenv; testenv.configure_for_tests()
 from testlib import sa, testing
 from sqlalchemy.orm import eagerload, deferred, undefer
-from testlib.sa import Table, Column, Integer, String, ForeignKey, and_
+from testlib.sa import Table, Column, Integer, String, Date, ForeignKey, and_, select, func
 from testlib.sa.orm import mapper, relation, create_session, lazyload
 from testlib.testing import eq_
 from testlib.assertsql import CompiledSQL
 from orm import _base, _fixtures
+import datetime
 
 class EagerTest(_fixtures.FixtureTest):
     run_inserts = 'once'
@@ -1312,6 +1313,157 @@ class SubqueryTest(_base.MappedTest):
             for t in (tags_table, users_table):
                 t.delete().execute()
 
+class CorrelatedSubqueryTest(_base.MappedTest):
+    """tests for #946, #947, #948.
+    
+    The "users" table is joined to "stuff", and the relation
+    would like to pull only the "stuff" entry with the most recent date.
+    
+    Exercises a variety of ways to configure this.
+    
+    """
+    
+    def define_tables(self, metadata):
+        users = Table('users', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('name', String(50))
+            )
+
+        stuff = Table('stuff', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('date', Date),
+            Column('user_id', Integer, ForeignKey('users.id')))
+    
+    @testing.resolve_artifact_names
+    def insert_data(self):
+        users.insert().execute(
+            {'id':1, 'name':'user1'},
+            {'id':2, 'name':'user2'},
+            {'id':3, 'name':'user3'},
+        )
+
+        stuff.insert().execute(
+            {'id':1, 'user_id':1, 'date':datetime.date(2007, 10, 15)},
+            {'id':2, 'user_id':1, 'date':datetime.date(2007, 12, 15)},
+            {'id':3, 'user_id':1, 'date':datetime.date(2007, 11, 15)},
+            {'id':4, 'user_id':2, 'date':datetime.date(2008, 1, 15)},
+            {'id':5, 'user_id':3, 'date':datetime.date(2007, 6, 15)},
+            {'id':6, 'user_id':3, 'date':datetime.date(2007, 3, 15)},
+        )
+        
+    
+    def test_labeled_on_date_noalias(self):
+        self._do_test('label', True, False)
+
+    def test_scalar_on_date_noalias(self):
+        self._do_test('scalar', True, False)
+
+    def test_plain_on_date_noalias(self):
+        self._do_test('none', True, False)
+
+    def test_labeled_on_limitid_noalias(self):
+        self._do_test('label', False, False)
+
+    def test_scalar_on_limitid_noalias(self):
+        self._do_test('scalar', False, False)
+
+    def test_plain_on_limitid_noalias(self):
+        self._do_test('none', False, False)
+
+    def test_labeled_on_date_alias(self):
+        self._do_test('label', True, True)
+
+    def test_scalar_on_date_alias(self):
+        self._do_test('scalar', True, True)
+
+    def test_plain_on_date_alias(self):
+        self._do_test('none', True, True)
+
+    def test_labeled_on_limitid_alias(self):
+        self._do_test('label', False, True)
+
+    def test_scalar_on_limitid_alias(self):
+        self._do_test('scalar', False, True)
+
+    def test_plain_on_limitid_alias(self):
+        self._do_test('none', False, True)
+        
+    @testing.resolve_artifact_names
+    def _do_test(self, labeled, ondate, aliasstuff):
+        class User(_base.ComparableEntity):
+            pass
+
+        class Stuff(_base.ComparableEntity):
+            pass
+        
+        mapper(Stuff, stuff)
+
+        if aliasstuff:
+            salias = stuff.alias()
+        else:
+            # if we don't alias the 'stuff' table within the correlated subquery, 
+            # it gets aliased in the eager load along with the "stuff" table to "stuff_1".
+            # but it's a scalar subquery, and this doesn't actually matter
+            salias = stuff
+
+        if ondate:
+            # the more 'relational' way to do this, join on the max date
+            stuff_view = select([func.max(salias.c.date).label('max_date')]).where(salias.c.user_id==users.c.id).correlate(users)
+        else:
+            # a common method with the MySQL crowd, which actually might perform better in some
+            # cases - subquery does a limit with order by DESC, join on the id
+            stuff_view = select([salias.c.id]).where(salias.c.user_id==users.c.id).correlate(users).order_by(salias.c.date.desc()).limit(1)
+
+        if labeled == 'label':
+            stuff_view = stuff_view.label('foo')
+        elif labeled == 'scalar':
+            stuff_view = stuff_view.as_scalar()
+
+        if ondate:
+            mapper(User, users, properties={
+                'stuff':relation(Stuff, primaryjoin=and_(users.c.id==stuff.c.user_id, stuff.c.date==stuff_view))
+            })
+        else:
+            mapper(User, users, properties={
+                'stuff':relation(Stuff, primaryjoin=and_(users.c.id==stuff.c.user_id, stuff.c.id==stuff_view))
+            })
+            
+        sess = create_session()
+        def go():
+            eq_(
+                sess.query(User).order_by(User.name).options(eagerload('stuff')).all(),
+                [
+                    User(name='user1', stuff=[Stuff(id=2)]),
+                    User(name='user2', stuff=[Stuff(id=4)]),
+                    User(name='user3', stuff=[Stuff(id=5)])
+                ]
+            )
+        self.assert_sql_count(testing.db, go, 1)
+    
+        sess = create_session()
+        def go():
+            eq_(
+                sess.query(User).order_by(User.name).first(),
+                User(name='user1', stuff=[Stuff(id=2)])
+            )
+        self.assert_sql_count(testing.db, go, 2)
+
+        sess = create_session()
+        def go():
+            eq_(
+                sess.query(User).order_by(User.name).options(eagerload('stuff')).first(),
+                User(name='user1', stuff=[Stuff(id=2)])
+            )
+        self.assert_sql_count(testing.db, go, 1)
+
+        sess = create_session()
+        def go():
+            eq_(
+                sess.query(User).filter(User.id==2).options(eagerload('stuff')).one(),
+                User(name='user2', stuff=[Stuff(id=4)])
+            )
+        self.assert_sql_count(testing.db, go, 1)
+
 
 if __name__ == '__main__':
     testenv.main()