]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add missing grouping for compound selects. fixes ticket #623
authorAnts Aasma <ants.aasma@gmail.com>
Tue, 26 Jun 2007 18:00:57 +0000 (18:00 +0000)
committerAnts Aasma <ants.aasma@gmail.com>
Tue, 26 Jun 2007 18:00:57 +0000 (18:00 +0000)
CHANGES
lib/sqlalchemy/sql.py
test/sql/query.py
test/sql/select.py

diff --git a/CHANGES b/CHANGES
index be73b34feaecfde82d2ef3fdefbb11d2bee0e2b9..66eeb3a144dc7d6a47e18cab3ec23f415fc9972e 100644 (file)
--- a/CHANGES
+++ b/CHANGES
       to polymorphic mappers that are using a straight "outerjoin"
       clause
 - sql
+    - fixed grouping of compound selects to give correct results. will break
+      on sqlite in some cases, but those cases were producing incorrect
+      results anyway, sqlite doesn't support grouped compound selects
+      [ticket:623]
     - fixed precedence of operators so that parenthesis are correctly applied
       [ticket:620]
     - calling <column>.in_() (i.e. with no arguments) will return 
index be1ed8a699edcdeb918c49f44b4a51d2b025cc0c..c86fc561a6b937fb5ceee4411bad5b0b33640cbf 100644 (file)
@@ -2427,6 +2427,8 @@ class _Grouping(ColumnElement):
         return self.elem._hide_froms()
     def _get_from_objects(self):
         return self.elem._get_from_objects()
+    def __getattr__(self, attr):
+        return getattr(self.elem, attr)
         
 class _Label(ColumnElement):
     """represent a label, as typically applied to any column-level element
@@ -2712,7 +2714,8 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
         self.is_scalar = False
         self.is_subquery = False
 
-        self.selects = selects
+        # unions group from left to right, so don't group first select
+        self.selects = [n and select.self_group(self) or select for n,select in enumerate(selects)]
 
         # some DBs do not like ORDER BY in the inner queries of a UNION, etc.
         for s in selects:
@@ -2945,6 +2948,8 @@ class Select(_SelectBaseMixin, FromClause):
             self.__hide_froms.add(f)
 
     def self_group(self, against=None):
+        if isinstance(against, CompoundSelect):
+            return self
         return _Grouping(self)
     
     def append_whereclause(self, whereclause):
index f7c38eb87d36d55a4528799a2a8e16899175a258..632246fad97f36e66333f9b44a52c4e801fdf6cf 100644 (file)
@@ -624,6 +624,29 @@ class CompoundTest(PersistTest):
         assert e.execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
         assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
 
+    @testbase.unsupported('sqlite', 'mysql', 'oracle')
+    def test_except_style3(self):
+        # aaa, bbb, ccc - (aaa, bbb, ccc - (ccc)) = ccc
+        e = except_(
+            select([t1.c.col3]), # aaa, bbb, ccc
+            except_(
+                select([t2.c.col3]), # aaa, bbb, ccc
+                select([t3.c.col3], t3.c.col3 == 'ccc'), #ccc
+            )
+        )
+        self.assertEquals(e.execute().fetchall(), [('ccc',)])
+
+    @testbase.unsupported('sqlite', 'mysql', 'oracle')
+    def test_union_union_all(self):
+        e = union_all(
+            select([t1.c.col3]),
+            union(
+                select([t1.c.col3]),
+                select([t1.c.col3]),
+            )
+        )
+        self.assertEquals(e.execute().fetchall(), [('aaa',),('bbb',),('ccc',),('aaa',),('bbb',),('ccc',)])
+
     @testbase.unsupported('mysql')
     def test_composite(self):
         u = intersect(
index 10fa631f0b5482a7f422dea4edaaa20e03b45d53..4d3eb4ad70b682c1172563c73a4f739b37878176 100644 (file)
@@ -663,6 +663,33 @@ FROM myothertable ORDER BY myid \
 WHERE mytable.name = :mytable_name GROUP BY mytable.myid, mytable.name UNION SELECT mytable.myid, mytable.name, mytable.description \
 FROM mytable WHERE mytable.name = :mytable_name_1"
             )
+    
+    def test_compound_select_grouping(self):
+            self.runtest(
+                union_all(
+                    select([table1.c.myid]),
+                    union(
+                        select([table2.c.otherid]),
+                        select([table3.c.userid]),
+                    )
+                )
+                ,
+                "SELECT mytable.myid FROM mytable UNION ALL (SELECT myothertable.otherid FROM myothertable UNION \
+SELECT thirdtable.userid FROM thirdtable)"
+            )
+            # This doesn't need grouping, so don't group to not give sqlite unnecessarily hard time
+            self.runtest(
+                union(
+                    except_(
+                        select([table2.c.otherid]),
+                        select([table3.c.userid]),
+                    ),
+                    select([table1.c.myid])
+                )
+                ,
+                "SELECT myothertable.otherid FROM myothertable EXCEPT SELECT thirdtable.userid FROM thirdtable \
+UNION SELECT mytable.myid FROM mytable"
+            )
             
     def testouterjoin(self):
         # test an outer join.  the oracle module should take the ON clause of the join and