]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added "aliased joins" feature to query.filter_by()
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 10 Jun 2007 23:53:03 +0000 (23:53 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 10 Jun 2007 23:53:03 +0000 (23:53 +0000)
- further work on modernizing/cleaning up unit tests

CHANGES
lib/sqlalchemy/orm/query.py
test/orm/alltests.py
test/orm/fixtures.py [new file with mode: 0644]
test/orm/inheritance/polymorph2.py
test/orm/lazy_relations.py [new file with mode: 0644]
test/orm/mapper.py
test/orm/query.py

diff --git a/CHANGES b/CHANGES
index 9952137830d4c364191cdf14b97e8c8c57fdfa2a..1c2cc8d604cca7fd2d4fb7264f769d1f1d83e42c 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -7,6 +7,11 @@
       new replacements.
     - query.list() replaced with query.all()
     - removed ancient query.select_by_attributename() capability.
+    - added "aliased joins" positional argument to the front of 
+      filter_by(). this allows auto-creation of joins that are aliased
+      locally to the individual filter_by() call.  This allows the 
+      auto-construction of joins which cross the same paths but
+      are querying divergent criteria.
     - along with recent speedups to ResultProxy, total number of
       function calls significantly reduced for large loads.
       test/perf/masseagerload.py reports 0.4 as having the fewest number
index d4318049b42afa7b8a544329b3d7d1f3df04dc14..a7bf73b80a33e50cb1a6c41cfa8b045419147fb9 100644 (file)
@@ -255,22 +255,37 @@ class Query(object):
 
         import properties
         
-        clause = None
+        if len(args) > 1:
+            raise exceptions.InvalidRequestError("filter_by() takes either zero positional arguments, or one scalar or list argument indicating a property search path.")
+        if len(args) == 1:
+            path = args[0]
+            (join, joinpoint, alias) = self._join_to(path, outerjoin=False, start=self.mapper, create_aliases=True)
+            clause = None
+        else:
+            alias = None
+            join = None
+            clause = None
+            joinpoint = self._joinpoint
 
         for key, value in kwargs.iteritems():
-            prop = self._joinpoint.props[key]
+            prop = joinpoint.props[key]
             if isinstance(prop, properties.PropertyLoader):
                 c = self._with_lazy_criterion(value, prop, True) # & self.join_via(keys[:-1]) - use aliasized join feature
             else:
                 c = prop.compare(value) # & self.join_via(keys) - use aliasized join feature
+            if alias is not None:
+                sql_util.ClauseAdapter(alias).traverse(c)
             if clause is None:
                 clause =  c
             else:
                 clause &= c
         
-        return self.filter(clause)
+        if join is not None:
+            return self.select_from(join).filter(clause)
+        else:
+            return self.filter(clause)
 
-    def _join_to(self, prop, outerjoin=False, start=None):
+    def _join_to(self, prop, outerjoin=False, start=None, create_aliases=False):
         if start is None:
             start = self._joinpoint
         
@@ -293,6 +308,7 @@ class Query(object):
         FindJoinedTables().traverse(clause)
             
         mapper = start
+        alias = None
         for key in keys:
             prop = mapper.props[key]
             if prop._is_self_referential():
@@ -307,12 +323,36 @@ class Query(object):
                         clause = clause.outerjoin(prop.select_table, prop.get_join(mapper))
                 else:
                     if prop.secondary:
-                        clause = clause.join(prop.secondary, prop.get_join(mapper, primary=True, secondary=False))
-                        clause = clause.join(prop.select_table, prop.get_join(mapper, primary=False))
+                        if create_aliases:
+                            join = prop.get_join(mapper, primary=True, secondary=False).copy_container()
+                            secondary_alias = prop.secondary.alias()
+                            if alias is not None:
+                                sql_util.ClauseAdapter(alias).traverse(join)
+                            sql_util.ClauseAdapter(secondary_alias).traverse(join)
+                            clause = clause.join(secondary_alias, join)
+                            alias = prop.select_table.alias()
+                            join = prop.get_join(mapper, primary=False).copy_container()
+                            sql_util.ClauseAdapter(secondary_alias).traverse(join)
+                            sql_util.ClauseAdapter(alias).traverse(join)
+                            clause = clause.join(alias, join)
+                        else:
+                            clause = clause.join(prop.secondary, prop.get_join(mapper, primary=True, secondary=False))
+                            clause = clause.join(prop.select_table, prop.get_join(mapper, primary=False))
                     else:
-                        clause = clause.join(prop.select_table, prop.get_join(mapper))
+                        if create_aliases:
+                            join = prop.get_join(mapper).copy_container()
+                            if alias is not None:
+                                sql_util.ClauseAdapter(alias).traverse(join)
+                            alias = prop.select_table.alias()
+                            sql_util.ClauseAdapter(alias).traverse(join)
+                            clause = clause.join(alias, join)
+                        else:
+                            clause = clause.join(prop.select_table, prop.get_join(mapper))
             mapper = prop.mapper
-        return (clause, mapper)
+        if create_aliases:
+            return (clause, mapper, alias)
+        else:
+            return (clause, mapper)
 
     def _generative_col_aggregate(self, col, func):
         """apply the given aggregate function to the query and return the newly
index 395b402bdb2c31f9f46ab67e128cce5ef68c8d5c..133d5176618ec7477ee7952194cf86352d4116c2 100644 (file)
@@ -7,6 +7,7 @@ def suite():
     modules_to_test = (
        'orm.attributes',
            'orm.query',
+           'orm.lazy_relations',
         'orm.mapper',
         'orm.generative',
         'orm.lazytest1',
diff --git a/test/orm/fixtures.py b/test/orm/fixtures.py
new file mode 100644 (file)
index 0000000..b79fdb5
--- /dev/null
@@ -0,0 +1,152 @@
+from sqlalchemy import *
+
+class Base(object):
+    def __init__(self, **kwargs):
+        for k in kwargs:
+            setattr(self, k, kwargs[k])
+            
+    def __ne__(self, other):
+        return not self.__eq__(other)
+        
+    def __eq__(self, other):
+        """'passively' compare this object to another.
+        
+        only look at attributes that are present on the source object.
+        
+        """
+        # use __dict__ to avoid instrumented properties
+        for attr in self.__dict__.keys():
+            if attr[0] == '_':
+                continue
+            value = getattr(self, attr)
+            if hasattr(value, '__iter__') and not isinstance(value, basestring):
+                if len(value) == 0:
+                    continue
+                for (us, them) in zip(value, getattr(other, attr)):
+                    if us != them:
+                        return False
+                else:
+                    continue
+            else:
+                if value is not None:
+                    if value != getattr(other, attr):
+                        return False
+        else:
+            return True
+
+class User(Base):pass
+class Order(Base):pass
+class Item(Base):pass
+class Keyword(Base):pass
+class Address(Base):pass
+
+metadata = MetaData()
+
+users = Table('users', metadata,
+    Column('id', Integer, primary_key=True),
+    Column('name', String(30), nullable=False))
+
+orders = Table('orders', metadata,
+    Column('id', Integer, primary_key=True),
+    Column('user_id', None, ForeignKey('users.id')),
+    Column('address_id', None, ForeignKey('addresses.id')),
+    Column('description', String(30)),
+    Column('isopen', Integer)
+    )
+
+addresses = Table('addresses', metadata, 
+    Column('id', Integer, primary_key=True),
+    Column('user_id', None, ForeignKey('users.id')),
+    Column('email_address', String(50), nullable=False))
+
+items = Table('items', metadata, 
+    Column('id', Integer, primary_key=True),
+    Column('description', String(30), nullable=False)
+    )
+
+order_items = Table('order_items', metadata,
+    Column('item_id', None, ForeignKey('items.id')),
+    Column('order_id', None, ForeignKey('orders.id')))
+
+item_keywords = Table('item_keywords', metadata, 
+    Column('item_id', None, ForeignKey('items.id')),
+    Column('keyword_id', None, ForeignKey('keywords.id')))
+
+keywords = Table('keywords', metadata, 
+    Column('id', Integer, primary_key=True),
+    Column('name', String(30), nullable=False)
+    )
+
+def install_fixture_data():
+    users.insert().execute(
+        dict(id = 7, name = 'jack'),
+        dict(id = 8, name = 'ed'),
+        dict(id = 9, name = 'fred'),
+        dict(id = 10, name = 'chuck'),
+
+    )
+    addresses.insert().execute(
+        dict(id = 1, user_id = 7, email_address = "jack@bean.com"),
+        dict(id = 2, user_id = 8, email_address = "ed@wood.com"),
+        dict(id = 3, user_id = 8, email_address = "ed@bettyboop.com"),
+        dict(id = 4, user_id = 8, email_address = "ed@lala.com"),
+        dict(id = 5, user_id = 9, email_address = "fred@fred.com"),
+    )
+    orders.insert().execute(
+        dict(id = 1, user_id = 7, description = 'order 1', isopen=0, address_id=1),
+        dict(id = 2, user_id = 9, description = 'order 2', isopen=0, address_id=4),
+        dict(id = 3, user_id = 7, description = 'order 3', isopen=1, address_id=1),
+        dict(id = 4, user_id = 9, description = 'order 4', isopen=1, address_id=4),
+        dict(id = 5, user_id = 7, description = 'order 5', isopen=0, address_id=1)
+    )
+    items.insert().execute(
+        dict(id=1, description='item 1'),
+        dict(id=2, description='item 2'),
+        dict(id=3, description='item 3'),
+        dict(id=4, description='item 4'),
+        dict(id=5, description='item 5'),
+    )
+    order_items.insert().execute(
+        dict(item_id=1, order_id=1),
+        dict(item_id=2, order_id=1),
+        dict(item_id=3, order_id=1),
+
+        dict(item_id=1, order_id=2),
+        dict(item_id=2, order_id=2),
+        dict(item_id=3, order_id=2),
+
+        dict(item_id=3, order_id=3),
+        dict(item_id=4, order_id=3),
+        dict(item_id=5, order_id=3),
+
+        dict(item_id=1, order_id=4),
+        dict(item_id=5, order_id=4),
+
+        dict(item_id=5, order_id=5),
+    )
+    keywords.insert().execute(
+        dict(id=1, name='blue'),
+        dict(id=2, name='red'),
+        dict(id=3, name='green'),
+        dict(id=4, name='big'),
+        dict(id=5, name='small'),
+        dict(id=6, name='round'),
+        dict(id=7, name='square')
+    )
+
+    # this many-to-many table has the keywords inserted
+    # in primary key order, to appease the unit tests.
+    # this is because postgres, oracle, and sqlite all support 
+    # true insert-order row id, but of course our pal MySQL does not,
+    # so the best it can do is order by, well something, so there you go.
+    item_keywords.insert().execute(
+        dict(keyword_id=2, item_id=1),
+        dict(keyword_id=2, item_id=2),
+        dict(keyword_id=4, item_id=1),
+        dict(keyword_id=6, item_id=1),
+        dict(keyword_id=5, item_id=2),
+        dict(keyword_id=3, item_id=3),
+        dict(keyword_id=4, item_id=3),
+        dict(keyword_id=7, item_id=2),
+        dict(keyword_id=6, item_id=3)
+    )
index ac47ff33cee1df9081bc37fd27d638b668c45a4c..a678994493d66974e433b702a76863bd5f9d8ab5 100644 (file)
@@ -205,20 +205,8 @@ class RelationTest3(testbase.ORMTest):
            Column('data', String(30))
            )
 
-    def testrelationonbaseclass_j1_nodata(self):
-       self.do_test("join1", False)
-    def testrelationonbaseclass_j2_nodata(self):
-       self.do_test("join2", False)
-    def testrelationonbaseclass_j1_data(self):
-       self.do_test("join1", True)
-    def testrelationonbaseclass_j2_data(self):
-       self.do_test("join2", True)
-    def testrelationonbaseclass_j3_nodata(self):
-       self.do_test("join3", False)
-    def testrelationonbaseclass_j3_data(self):
-       self.do_test("join3", True)
-
-    def do_test(self, jointype="join1", usedata=False):
+def generate_test(jointype="join1", usedata=False):
+    def do_test(self):
         class Person(AttrSettable):
             pass
         class Manager(Person):
@@ -239,12 +227,14 @@ class RelationTest3(testbase.ORMTest):
                 'manager':join(people, managers, people.c.person_id==managers.c.person_id),
                 'person':people.select(people.c.type=='person')
             }, None)
-        elif jointype == "join3":
+        elif jointype == 'join3':
+            poly_union = people.outerjoin(managers)
+        elif jointype == "join4":
             poly_union=None
-            
+        
         if usedata:
             mapper(Data, data)
-        
+    
         mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, polymorphic_identity='manager')
         if usedata:
             mapper(Person, people, select_table=poly_union, polymorphic_identity='person', polymorphic_on=people.c.type,
@@ -275,7 +265,7 @@ class RelationTest3(testbase.ORMTest):
         sess.save(m)
         sess.save(p)
         sess.flush()
-        
+    
         sess.clear()
         p = sess.query(Person).get(p.person_id)
         p2 = sess.query(Person).get(p2.person_id)
@@ -288,7 +278,15 @@ class RelationTest3(testbase.ORMTest):
         if usedata:
             assert p.data.data == 'ps data'
             assert m.data.data == 'ms data'
+            
+    do_test.__name__ = 'test_relationonbaseclass_%s_%s' % (jointype, data and "nodata" or "data")            
+    return do_test
 
+for jointype in ["join1", "join2", "join3", "join4"]:
+    for data in (True, False):
+        func = generate_test(jointype, data)
+        setattr(RelationTest3, func.__name__, func)
+            
         
 class RelationTest4(testbase.ORMTest):
     def define_tables(self, metadata):
@@ -723,7 +721,7 @@ class GenerativeTest(testbase.AssertMixin):
 
         # test these twice because theres caching involved, as well previous issues that modified the polymorphic union
         for x in range(0, 2):
-            r = session.query(Person).filter_by(people.c.name.like('%2')).join('status').filter_by(name="active")
+            r = session.query(Person).filter(people.c.name.like('%2')).join('status').filter_by(name="active")
             assert str(list(r)) == "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]"
             r = session.query(Engineer).join('status').filter(people.c.name.in_('E2', 'E3', 'E4', 'M4', 'M2', 'M1') & (status.c.name=="active"))
             assert str(list(r)) == "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]"
diff --git a/test/orm/lazy_relations.py b/test/orm/lazy_relations.py
new file mode 100644 (file)
index 0000000..e0a6cee
--- /dev/null
@@ -0,0 +1,182 @@
+from sqlalchemy import *
+from sqlalchemy.orm import *
+import testbase
+
+from fixtures import *
+from query import QueryTest
+
+class LazyTest(QueryTest):
+    keep_mappers = False
+
+    def setup_mappers(self):
+        pass
+        
+    def test_basic(self):
+        mapper(User, users, properties={
+            'addresses':relation(mapper(Address, addresses), lazy=True)
+        })
+        sess = create_session()
+        q = sess.query(User)
+        assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(users.c.id == 7).all()
+
+    def test_bindstosession(self):
+        """test that lazy loaders use the mapper's contextual session if the parent instance
+        is not in a session, and that an error is raised if no contextual session"""
+        
+        from sqlalchemy.ext.sessioncontext import SessionContext
+        ctx = SessionContext(create_session)
+        m = mapper(User, users, properties = dict(
+            addresses = relation(mapper(Address, addresses, extension=ctx.mapper_extension), lazy=True)
+        ), extension=ctx.mapper_extension)
+        q = ctx.current.query(m)
+        u = q.filter(users.c.id == 7).first()
+        ctx.current.expunge(u)
+        assert User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')]) == u
+
+        clear_mappers()
+
+        mapper(User, users, properties={
+            'addresses':relation(mapper(Address, addresses), lazy=True)
+        })
+        try:
+            sess = create_session()
+            q = sess.query(User)
+            u = q.filter(users.c.id == 7).first()
+            sess.expunge(u)
+            assert User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')]) == u
+        except exceptions.InvalidRequestError, err:
+            assert "not bound to a Session, and no contextual session" in str(err)
+
+    def test_orderby(self):
+        mapper(User, users, properties = {
+            'addresses':relation(mapper(Address, addresses), lazy=True, order_by=addresses.c.email_address),
+        })
+        q = create_session().query(User)
+        assert [
+            User(id=7, addresses=[
+                Address(id=1)
+            ]), 
+            User(id=8, addresses=[
+                Address(id=3, email_address='ed@bettyboop.com'),
+                Address(id=4, email_address='ed@lala.com'),
+                Address(id=2, email_address='ed@wood.com')
+            ]), 
+            User(id=9, addresses=[
+                Address(id=5)
+            ]), 
+            User(id=10, addresses=[])
+        ] == q.all()
+        
+    def test_orderby_secondary(self):
+        """tests that a regular mapper select on a single table can order by a relation to a second table"""
+        
+        mapper(Address, addresses)
+
+        mapper(User, users, properties = dict(
+            addresses = relation(Address, lazy=True),
+        ))
+        q = create_session().query(User)
+        l = q.filter(users.c.id==addresses.c.user_id).order_by(addresses.c.email_address).all()
+        assert [
+            User(id=8, addresses=[
+                Address(id=2, email_address='ed@wood.com'),
+                Address(id=3, email_address='ed@bettyboop.com'),
+                Address(id=4, email_address='ed@lala.com'),
+            ]), 
+            User(id=9, addresses=[
+                Address(id=5)
+            ]), 
+            User(id=7, addresses=[
+                Address(id=1)
+            ]), 
+        ] == l
+
+    def test_orderby_desc(self):
+        mapper(Address, addresses)
+
+        mapper(User, users, properties = dict(
+            addresses = relation(Address, lazy=True,  order_by=[desc(addresses.c.email_address)]),
+        ))
+        sess = create_session()
+        assert [
+            User(id=7, addresses=[
+                Address(id=1)
+            ]), 
+            User(id=8, addresses=[
+                Address(id=2, email_address='ed@wood.com'),
+                Address(id=4, email_address='ed@lala.com'),
+                Address(id=3, email_address='ed@bettyboop.com'),
+            ]), 
+            User(id=9, addresses=[
+                Address(id=5)
+            ]), 
+            User(id=10, addresses=[])
+        ] == sess.query(User).all()
+
+    def test_no_orphan(self):
+        """test that a lazily loaded child object is not marked as an orphan"""
+
+        mapper(User, users, properties={
+            'addresses':relation(Address, cascade="all,delete-orphan", lazy=True)
+        })
+        mapper(Address, addresses)
+
+        sess = create_session()
+        user = sess.query(User).get(7)
+        assert getattr(User, 'addresses').hasparent(user.addresses[0], optimistic=True)
+        assert not class_mapper(Address)._is_orphan(user.addresses[0])
+
+    def test_limit(self):
+        """test limit operations combined with lazy-load relationships."""
+        
+        mapper(Item, items)
+        mapper(Order, orders, properties={
+            'items':relation(Item, secondary=order_items, lazy=True)
+        })
+        mapper(User, users, properties={
+            'addresses':relation(mapper(Address, addresses), lazy=True),
+            'orders':relation(Order, lazy=True)
+        })
+
+        sess = create_session()
+        q = sess.query(User)
+
+        if testbase.db.engine.name == 'mssql':
+            l = q.limit(2).all()
+            assert self.user_all_result[:2] == l
+        else:        
+            l = q.limit(2).offset(1).all()
+            print l
+            print self.user_all_result[1:3]
+            assert self.user_all_result[1:3] == l
+
+    def test_distinct(self):
+        mapper(Item, items)
+        mapper(Order, orders, properties={
+            'items':relation(Item, secondary=order_items, lazy=True)
+        })
+        mapper(User, users, properties={
+            'addresses':relation(mapper(Address, addresses), lazy=True),
+            'orders':relation(Order, lazy=True)
+        })
+
+        sess = create_session()
+        q = sess.query(User)
+
+        # use a union all to get a lot of rows to join against
+        u2 = users.alias('u2')
+        s = union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u')
+        print [key for key in s.c.keys()]
+        l = q.filter(s.c.u2_id==User.c.id).distinct().all()
+        assert self.user_all_result == l
+
+    def test_onetoone(self):
+        mapper(User, users, properties = dict(
+            address = relation(mapper(Address, addresses), lazy=True, uselist=False)
+        ))
+        q = create_session().query(User)
+        l = q.filter(users.c.id == 7).all()
+        assert [User(id=7, address=Address(id=1))] == l
+
+if __name__ == '__main__':
+    testbase.main()
index 43125982670c89b61485c48db113cc25348b69df..97de21712034750010182fbf21db47833a82dd56 100644 (file)
@@ -848,31 +848,7 @@ class NoLoadTest(MapperSuperTest):
             )
 
 
-class LazyTest(MapperSuperTest):
-
-    def testbasic(self):
-        """tests a basic one-to-many lazy load"""
-        m = mapper(User, users, properties = dict(
-            addresses = relation(mapper(Address, addresses), lazy = True)
-        ))
-        q = create_session().query(m)
-        l = q.select(users.c.user_id == 7)
-        self.assert_result(l, User,
-            {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])},
-            )
-
-    def testbindstosession(self):
-        ctx = SessionContext(create_session)
-        m = mapper(User, users, properties = dict(
-            addresses = relation(mapper(Address, addresses, extension=ctx.mapper_extension), lazy=True)
-        ), extension=ctx.mapper_extension)
-        q = ctx.current.query(m)
-        u = q.filter(users.c.user_id == 7).selectfirst()
-        ctx.current.expunge(u)
-        self.assert_result([u], User,
-            {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])},
-            )
-    
+class MapperExtensionTest(MapperSuperTest):
     def testcreateinstance(self):
         class Ext(MapperExtension):
             def create_instance(self, *args, **kwargs):
@@ -885,106 +861,9 @@ class LazyTest(MapperSuperTest):
         q = create_session().query(m)
         l = q.select();
         self.assert_result(l, User, *user_address_result)
-        
-    def testorderby(self):
-        m = mapper(Address, addresses)
-
-        m = mapper(User, users, properties = dict(
-            addresses = relation(m, lazy = True, order_by=addresses.c.email_address),
-        ))
-        q = create_session().query(m)
-        l = q.select()
-
-        self.assert_result(l, User,
-            {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])},
-            {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@bettyboop.com'}, {'email_address':'ed@lala.com'}, {'email_address':'ed@wood.com'}])},
-            {'user_id' : 9, 'addresses' : (Address, [])}
-            )
-
-    def testorderby_select(self):
-        """tests that a regular mapper select on a single table can order by a relation to a second table"""
-        m = mapper(Address, addresses)
-
-        m = mapper(User, users, properties = dict(
-            addresses = relation(m, lazy = True),
-        ))
-        q = create_session().query(m)
-        l = q.select(users.c.user_id==addresses.c.user_id, order_by=addresses.c.email_address)
-
-        self.assert_result(l, User,
-            {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@wood.com'}, {'email_address':'ed@bettyboop.com'}, {'email_address':'ed@lala.com'}, ])},
-            {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])},
-        )
-        
-    def testorderby_desc(self):
-        m = mapper(Address, addresses)
-
-        m = mapper(User, users, properties = dict(
-            addresses = relation(m, lazy = True, order_by=[desc(addresses.c.email_address)]),
-        ))
-        q = create_session().query(m)
-        l = q.select()
-
-        self.assert_result(l, User,
-            {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])},
-            {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@wood.com'}, {'email_address':'ed@lala.com'}, {'email_address':'ed@bettyboop.com'}])},
-            {'user_id' : 9, 'addresses' : (Address, [])},
-            )
-
-    def testorphanstate(self):
-        """test that a lazily loaded child object is not marked as an orphan"""
-        m = mapper(User, users, properties={
-            'addresses':relation(Address, cascade="all,delete-orphan", lazy=True)
-        })
-        mapper(Address, addresses)
-
-        q = create_session().query(m)
-        user = q.get(7)
-        assert getattr(User, 'addresses').hasparent(user.addresses[0], optimistic=True)
-        assert not class_mapper(Address)._is_orphan(user.addresses[0])
-        
-    def testlimit(self):
-        ordermapper = mapper(Order, orders, properties = dict(
-                items = relation(mapper(Item, orderitems), lazy = True)
-            ))
-
-        m = mapper(User, users, properties = dict(
-            addresses = relation(mapper(Address, addresses), lazy = True),
-            orders = relation(ordermapper, primaryjoin = users.c.user_id==orders.c.user_id, lazy = True),
-        ))
-        sess= create_session()
-        q = sess.query(m)
-        
-        if db.engine.name == 'mssql':
-            l = q.select(limit=2)
-            self.assert_result(l, User, *user_all_result[:2])
-        else:        
-            l = q.select(limit=2, offset=1)
-            self.assert_result(l, User, *user_all_result[1:3])
-
-        # use a union all to get a lot of rows to join against
-        u2 = users.alias('u2')
-        s = union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u')
-        print [key for key in s.c.keys()]
-        l = q.select(s.c.u2_user_id==User.c.user_id, distinct=True)
-        self.assert_result(l, User, *user_all_result)
-        
-        sess.clear()
-        clear_mappers()
-        m = mapper(Item, orderitems, properties = dict(
-                keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy = True),
-            ))
-        
-        l = sess.query(m).select((Item.c.item_name=='item 2') | (Item.c.item_name=='item 5') | (Item.c.item_name=='item 3'), order_by=[Item.c.item_id], limit=2)        
-        self.assert_result(l, Item, *[item_keyword_result[1], item_keyword_result[2]])
+    
+class LazyTest(MapperSuperTest):
 
-    def testonetoone(self):
-        m = mapper(User, users, properties = dict(
-            address = relation(mapper(Address, addresses), lazy = True, uselist = False)
-        ))
-        q = create_session().query(m)
-        l = q.select(users.c.user_id == 7)
-        self.assert_result(l, User, {'user_id':7, 'address' : (Address, {'address_id':1})})
 
     def testbackwardsonetoone(self):
         m = mapper(Address, addresses, properties = dict(
index b7b906466d29211bea989a102e143733f7f2e29c..9527b25d834a0cef1799d050814b902d20a53f72 100644 (file)
@@ -1,6 +1,7 @@
 from sqlalchemy import *
 from sqlalchemy.orm import *
 import testbase
+from fixtures import *
 
 class Base(object):
     def __init__(self, **kwargs):
@@ -42,132 +43,19 @@ class QueryTest(testbase.ORMTest):
     
     def setUpAll(self):
         super(QueryTest, self).setUpAll()
-        self.install_fixture_data()
+        install_fixture_data()
         self.setup_mappers()
-        
-    def define_tables(self, metadata):
-        global users, orders, addresses, items, order_items, item_keywords, keywords
-        
-        users = Table('users', metadata,
-            Column('id', Integer, primary_key=True),
-            Column('name', String(30), nullable=False))
-
-        orders = Table('orders', metadata,
-            Column('id', Integer, primary_key=True),
-            Column('user_id', None, ForeignKey('users.id')),
-            Column('address_id', None, ForeignKey('addresses.id')),
-            Column('description', String(30)),
-            Column('isopen', Integer)
-            )
-
-        addresses = Table('addresses', metadata, 
-            Column('id', Integer, primary_key=True),
-            Column('user_id', None, ForeignKey('users.id')),
-            Column('email_address', String(50), nullable=False))
-            
-        items = Table('items', metadata, 
-            Column('id', Integer, primary_key=True),
-            Column('description', String(30), nullable=False)
-            )
-
-        order_items = Table('order_items', metadata,
-            Column('item_id', None, ForeignKey('items.id')),
-            Column('order_id', None, ForeignKey('orders.id')))
-
-        item_keywords = Table('item_keywords', metadata, 
-            Column('item_id', None, ForeignKey('items.id')),
-            Column('keyword_id', None, ForeignKey('keywords.id')))
-
-        keywords = Table('keywords', metadata, 
-            Column('id', Integer, primary_key=True),
-            Column('name', String(30), nullable=False)
-            )
-        
-    def install_fixture_data(self):
-        users.insert().execute(
-            dict(id = 7, name = 'jack'),
-            dict(id = 8, name = 'ed'),
-            dict(id = 9, name = 'fred'),
-            dict(id = 10, name = 'chuck'),
-            
-        )
-        addresses.insert().execute(
-            dict(id = 1, user_id = 7, email_address = "jack@bean.com"),
-            dict(id = 2, user_id = 8, email_address = "ed@wood.com"),
-            dict(id = 3, user_id = 8, email_address = "ed@bettyboop.com"),
-            dict(id = 4, user_id = 8, email_address = "ed@lala.com"),
-            dict(id = 5, user_id = 9, email_address = "fred@fred.com"),
-        )
-        orders.insert().execute(
-            dict(id = 1, user_id = 7, description = 'order 1', isopen=0, address_id=1),
-            dict(id = 2, user_id = 9, description = 'order 2', isopen=0, address_id=4),
-            dict(id = 3, user_id = 7, description = 'order 3', isopen=1, address_id=1),
-            dict(id = 4, user_id = 9, description = 'order 4', isopen=1, address_id=4),
-            dict(id = 5, user_id = 7, description = 'order 5', isopen=0, address_id=1)
-        )
-        items.insert().execute(
-            dict(id=1, description='item 1'),
-            dict(id=2, description='item 2'),
-            dict(id=3, description='item 3'),
-            dict(id=4, description='item 4'),
-            dict(id=5, description='item 5'),
-        )
-        order_items.insert().execute(
-            dict(item_id=1, order_id=1),
-            dict(item_id=2, order_id=1),
-            dict(item_id=3, order_id=1),
-
-            dict(item_id=1, order_id=2),
-            dict(item_id=2, order_id=2),
-            dict(item_id=3, order_id=2),
-            dict(item_id=2, order_id=2),
-
-            dict(item_id=3, order_id=3),
-            dict(item_id=4, order_id=3),
-            dict(item_id=5, order_id=3),
-            
-            dict(item_id=5, order_id=4),
-            dict(item_id=1, order_id=4),
-            
-            dict(item_id=5, order_id=5),
-        )
-        keywords.insert().execute(
-            dict(id=1, name='blue'),
-            dict(id=2, name='red'),
-            dict(id=3, name='green'),
-            dict(id=4, name='big'),
-            dict(id=5, name='small'),
-            dict(id=6, name='round'),
-            dict(id=7, name='square')
-        )
-
-        # this many-to-many table has the keywords inserted
-        # in primary key order, to appease the unit tests.
-        # this is because postgres, oracle, and sqlite all support 
-        # true insert-order row id, but of course our pal MySQL does not,
-        # so the best it can do is order by, well something, so there you go.
-        item_keywords.insert().execute(
-            dict(keyword_id=2, item_id=1),
-            dict(keyword_id=2, item_id=2),
-            dict(keyword_id=4, item_id=1),
-            dict(keyword_id=6, item_id=1),
-            dict(keyword_id=5, item_id=2),
-            dict(keyword_id=3, item_id=3),
-            dict(keyword_id=4, item_id=3),
-            dict(keyword_id=7, item_id=2),
-            dict(keyword_id=6, item_id=3)
-        )
-
     
-    def setup_mappers(self):
-        global User, Order, Item, Keyword, Address
+    def tearDownAll(self):
+        clear_mappers()
+        super(QueryTest, self).tearDownAll()
+          
+    def define_tables(self, meta):
+        # a slight dirty trick here. 
+        meta.tables = metadata.tables
+        metadata.connect(meta.engine)
         
-        class User(Base):pass
-        class Order(Base):pass
-        class Item(Base):pass
-        class Keyword(Base):pass
-        class Address(Base):pass
-
+    def setup_mappers(self):
         mapper(User, users, properties={
             'addresses':relation(Address),
             'orders':relation(Order, backref='user'), # o2m, m2o
@@ -199,6 +87,30 @@ class QueryTest(testbase.ORMTest):
             User(id=10, addresses=[])
         ]
 
+    @property
+    def user_all_result(self):
+        return [
+            User(id=7, addresses=[
+                Address(id=1)
+            ], orders=[
+                Order(description='order 1', items=[Item(description='item 1'), Item(description='item 2'), Item(description='item 3')]),
+                Order(description='order 3'),
+                Order(description='order 5'),
+            ]), 
+            User(id=8, addresses=[
+                Address(id=2),
+                Address(id=3),
+                Address(id=4)
+            ]), 
+            User(id=9, addresses=[
+                Address(id=5)
+            ], orders=[
+                Order(description='order 2', items=[Item(description='item 1'), Item(description='item 2'), Item(description='item 3')]),
+                Order(description='order 4', items=[Item(description='item 1'), Item(description='item 5')]),
+            ]), 
+            User(id=10, addresses=[])
+        ]
+        
 class GetTest(QueryTest):
     def test_get(self):
         s = create_session()
@@ -289,6 +201,7 @@ class ParentTest(QueryTest):
     
 class JoinTest(QueryTest):
     def test_overlapping_paths(self):
+        # load a user who has an order that contains item id 3 and address id 1 (order 3, owned by jack)
         result = create_session().query(User).join(['orders', 'items']).filter_by(id=3).join(['orders','address']).filter_by(id=1).all()
         assert [User(id=7, name='jack')] == result
     
@@ -378,7 +291,7 @@ class InstancesTest(QueryTest):
 
     def test_multi_columns(self):
         sess = create_session()
-        (user7, user8, user9, user10) = sess.query(User).select()
+        (user7, user8, user9, user10) = sess.query(User).all()
         expected = [(user7, 1),
             (user8, 3),
             (user9, 1),
@@ -398,7 +311,7 @@ class InstancesTest(QueryTest):
     @testbase.unsupported('mysql') # only because of "+" operator requiring "concat" in mysql (fix #475)
     def test_two_columns(self):
         sess = create_session()
-        (user7, user8, user9, user10) = sess.query(User).select()
+        (user7, user8, user9, user10) = sess.query(User).all()
         expected = [
             (user7, 1, "Name:jack"),
             (user8, 3, "Name:ed"),
@@ -410,7 +323,32 @@ class InstancesTest(QueryTest):
         l = q.instances(s.execute(), "count", "concat")
         assert l == expected
 
-
+class FilterByTest(QueryTest):
+    def test_aliased(self):
+        """test automatic generation of aliased joins using filter_by()."""
+        
+        sess = create_session()
+        
+        # test a basic aliasized path
+        q = sess.query(User).filter_by(['addresses'], email_address='jack@bean.com')
+        assert [User(id=7)] == q.all()
+
+        # test two aliasized paths, one to 'orders' and the other to 'orders','items'.
+        # one row is returned because user 7 has order 3 and also has order 1 which has item 1
+        # this tests a o2m join and a m2m join.
+        q = sess.query(User).filter_by(['orders'], description="order 3").filter_by(['orders', 'items'], description="item 1")
+        assert q.count() == 1
+        assert [User(id=7)] == q.all()
+        
+        # test the control version - same joins but not aliased.  rows are not returned because order 3 does not have item 1
+        # addtionally by placing this test after the previous one, test that the "aliasing" step does not corrupt the
+        # join clauses that are cached by the relationship.
+        q = sess.query(User).join('orders').filter_by(description="order 3").join(['orders', 'items']).filter_by(description="item 1")
+        assert [] == q.all()
+        assert q.count() == 0
+        
+        
+        
 
 if __name__ == '__main__':
     testbase.main()