import testenv; testenv.configure_for_tests()
from testlib import sa, testing
from sqlalchemy.orm import eagerload, deferred, undefer
-from testlib.sa import Table, Column, Integer, String, ForeignKey, and_
+from testlib.sa import Table, Column, Integer, String, Date, ForeignKey, and_, select, func
from testlib.sa.orm import mapper, relation, create_session, lazyload
from testlib.testing import eq_
from testlib.assertsql import CompiledSQL
from orm import _base, _fixtures
+import datetime
class EagerTest(_fixtures.FixtureTest):
run_inserts = 'once'
for t in (tags_table, users_table):
t.delete().execute()
+class CorrelatedSubqueryTest(_base.MappedTest):
+ """tests for #946, #947, #948.
+
+ The "users" table is joined to "stuff", and the relation
+ would like to pull only the "stuff" entry with the most recent date.
+
+ Exercises a variety of ways to configure this.
+
+ """
+
+ def define_tables(self, metadata):
+ users = Table('users', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', String(50))
+ )
+
+ stuff = Table('stuff', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('date', Date),
+ Column('user_id', Integer, ForeignKey('users.id')))
+
+ @testing.resolve_artifact_names
+ def insert_data(self):
+ users.insert().execute(
+ {'id':1, 'name':'user1'},
+ {'id':2, 'name':'user2'},
+ {'id':3, 'name':'user3'},
+ )
+
+ stuff.insert().execute(
+ {'id':1, 'user_id':1, 'date':datetime.date(2007, 10, 15)},
+ {'id':2, 'user_id':1, 'date':datetime.date(2007, 12, 15)},
+ {'id':3, 'user_id':1, 'date':datetime.date(2007, 11, 15)},
+ {'id':4, 'user_id':2, 'date':datetime.date(2008, 1, 15)},
+ {'id':5, 'user_id':3, 'date':datetime.date(2007, 6, 15)},
+ {'id':6, 'user_id':3, 'date':datetime.date(2007, 3, 15)},
+ )
+
+
+ def test_labeled_on_date_noalias(self):
+ self._do_test('label', True, False)
+
+ def test_scalar_on_date_noalias(self):
+ self._do_test('scalar', True, False)
+
+ def test_plain_on_date_noalias(self):
+ self._do_test('none', True, False)
+
+ def test_labeled_on_limitid_noalias(self):
+ self._do_test('label', False, False)
+
+ def test_scalar_on_limitid_noalias(self):
+ self._do_test('scalar', False, False)
+
+ def test_plain_on_limitid_noalias(self):
+ self._do_test('none', False, False)
+
+ def test_labeled_on_date_alias(self):
+ self._do_test('label', True, True)
+
+ def test_scalar_on_date_alias(self):
+ self._do_test('scalar', True, True)
+
+ def test_plain_on_date_alias(self):
+ self._do_test('none', True, True)
+
+ def test_labeled_on_limitid_alias(self):
+ self._do_test('label', False, True)
+
+ def test_scalar_on_limitid_alias(self):
+ self._do_test('scalar', False, True)
+
+ def test_plain_on_limitid_alias(self):
+ self._do_test('none', False, True)
+
+ @testing.resolve_artifact_names
+ def _do_test(self, labeled, ondate, aliasstuff):
+ class User(_base.ComparableEntity):
+ pass
+
+ class Stuff(_base.ComparableEntity):
+ pass
+
+ mapper(Stuff, stuff)
+
+ if aliasstuff:
+ salias = stuff.alias()
+ else:
+ # if we don't alias the 'stuff' table within the correlated subquery,
+ # it gets aliased in the eager load along with the "stuff" table to "stuff_1".
+ # but it's a scalar subquery, and this doesn't actually matter
+ salias = stuff
+
+ if ondate:
+ # the more 'relational' way to do this, join on the max date
+ stuff_view = select([func.max(salias.c.date).label('max_date')]).where(salias.c.user_id==users.c.id).correlate(users)
+ else:
+ # a common method with the MySQL crowd, which actually might perform better in some
+ # cases - subquery does a limit with order by DESC, join on the id
+ stuff_view = select([salias.c.id]).where(salias.c.user_id==users.c.id).correlate(users).order_by(salias.c.date.desc()).limit(1)
+
+ if labeled == 'label':
+ stuff_view = stuff_view.label('foo')
+ elif labeled == 'scalar':
+ stuff_view = stuff_view.as_scalar()
+
+ if ondate:
+ mapper(User, users, properties={
+ 'stuff':relation(Stuff, primaryjoin=and_(users.c.id==stuff.c.user_id, stuff.c.date==stuff_view))
+ })
+ else:
+ mapper(User, users, properties={
+ 'stuff':relation(Stuff, primaryjoin=and_(users.c.id==stuff.c.user_id, stuff.c.id==stuff_view))
+ })
+
+ sess = create_session()
+ def go():
+ eq_(
+ sess.query(User).order_by(User.name).options(eagerload('stuff')).all(),
+ [
+ User(name='user1', stuff=[Stuff(id=2)]),
+ User(name='user2', stuff=[Stuff(id=4)]),
+ User(name='user3', stuff=[Stuff(id=5)])
+ ]
+ )
+ self.assert_sql_count(testing.db, go, 1)
+
+ sess = create_session()
+ def go():
+ eq_(
+ sess.query(User).order_by(User.name).first(),
+ User(name='user1', stuff=[Stuff(id=2)])
+ )
+ self.assert_sql_count(testing.db, go, 2)
+
+ sess = create_session()
+ def go():
+ eq_(
+ sess.query(User).order_by(User.name).options(eagerload('stuff')).first(),
+ User(name='user1', stuff=[Stuff(id=2)])
+ )
+ self.assert_sql_count(testing.db, go, 1)
+
+ sess = create_session()
+ def go():
+ eq_(
+ sess.query(User).filter(User.id==2).options(eagerload('stuff')).one(),
+ User(name='user2', stuff=[Stuff(id=4)])
+ )
+ self.assert_sql_count(testing.db, go, 1)
+
if __name__ == '__main__':
testenv.main()