]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Added some outerjoin() execution exercises to the query tests.
authorJason Kirtland <jek@discorporate.us>
Mon, 8 Oct 2007 02:17:07 +0000 (02:17 +0000)
committerJason Kirtland <jek@discorporate.us>
Mon, 8 Oct 2007 02:17:07 +0000 (02:17 +0000)
test/sql/query.py

index 1a08a0e304bf908c71aa003c49473ad03328a7ae..00eac66fd0630b0a5c96c6a2a1dadfea86763973 100644 (file)
@@ -761,7 +761,73 @@ class CompoundTest(PersistTest):
         wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
         found = self._fetchall_sorted(ua.select().execute())
         self.assertEquals(found, wanted)
+
+
+class JoinTest(PersistTest):
+    """Tests join execution."""
     
+    def setUpAll(self):
+        global metadata
+        global t1, t2, t3
+
+        metadata = MetaData(testbase.db)
+        t1 = Table('t1', metadata,
+                   Column('t1_id', Integer, primary_key=True),
+                   Column('name', String(32)))
+        t2 = Table('t2', metadata,
+                   Column('t2_id', Integer, primary_key=True),
+                   Column('t1_id', Integer, ForeignKey('t1.t1_id')),
+                   Column('name', String(32)))
+        t3 = Table('t3', metadata,
+                   Column('t3_id', Integer, primary_key=True),
+                   Column('t2_id', Integer, ForeignKey('t2.t2_id')),
+                   Column('name', String(32)))
+        metadata.drop_all()
+        metadata.create_all()
+
+        # t1.1 -> t2.1 -> t3.1
+        # t1.2 -> t2.2
+        # t1.3
+        t1.insert().execute([{'t1_id': i, 'name': 't1 #%s' % i}
+                             for i in (1, 2, 3)])
+        t2.insert().execute([{'t2_id': i, 't1_id': i, 'name': 't2 #%s' % i}
+                             for i in (1, 2)])
+        t3.insert().execute([{'t3_id': i, 't2_id': i, 'name': 't3 #%s' % i}
+                             for i in (1,)])
+
+    def tearDownAll(self):
+        metadata.drop_all()
+
+    def assertRows(self, statement, expected):
+        """Execute a statement and assert that rows returned equal expected."""
+        
+        found = exec_sorted(statement)
+        self.assertEquals(found, sorted(expected))
+
+    def test_outerjoin_x1(self):
+        expr_left = select(
+            [t1.c.t1_id, t2.c.t2_id],
+            from_obj=[t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id)])
+        self.assertRows(expr_left, [(1, 1), (2, 2), (3, None)])
+
+        expr_right = select(
+            [t1.c.t1_id, t2.c.t2_id],
+            from_obj=[t1.outerjoin(t2, t2.c.t1_id==t1.c.t1_id)])
+        self.assertRows(expr_right, [(1, 1), (2, 2), (3, None)])
+
+    def test_outerjoin_x2(self):
+        expr_left = select(
+            [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+            from_obj=[t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id). \
+                         outerjoin(t3, t2.c.t2_id==t3.c.t3_id)])
+        self.assertRows(expr_left, [(1, 1, 1), (2, 2, None), (3, None, None)])
+
+        expr_right = select(
+            [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+            from_obj=[t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id). \
+                         outerjoin(t3, t3.c.t2_id==t2.c.t2_id)])
+        self.assertRows(expr_right, [(1, 1, 1), (2, 2, None), (3, None, None)])
+
 
 class OperatorTest(PersistTest):
     def setUpAll(self):
@@ -787,6 +853,14 @@ class OperatorTest(PersistTest):
             select([flds.c.intcol % 3], order_by=flds.c.idcol).execute().fetchall(),
             [(2,),(1,)]
         )
-        
+
+
+def exec_sorted(statement, *args, **kw):
+    """Executes a statement and returns a sorted list plain tuple rows."""
+
+    return sorted([tuple(row)
+                   for row in statement.execute(*args, **kw).fetchall()])
+
+
 if __name__ == "__main__":
     testbase.main()