]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- [bug] Repaired common table expression
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 13 Jun 2012 22:19:26 +0000 (18:19 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 13 Jun 2012 22:19:26 +0000 (18:19 -0400)
    rendering to function correctly when the
    SELECT statement contains UNION or other
    compound expressions, courtesy btbuilder.
    [ticket:2490]

CHANGES
lib/sqlalchemy/sql/compiler.py
test/sql/test_compiler.py

diff --git a/CHANGES b/CHANGES
index 0fb6a9f856bf50c2b70f9d63a3f28de478249314..e1148ae18991550fafb90d544b46177124b24d1c 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -39,6 +39,12 @@ CHANGES
     this breakage doesn't occur again.
     [ticket:2499]
 
+  - [bug] Repaired common table expression
+    rendering to function correctly when the 
+    SELECT statement contains UNION or other
+    compound expressions, courtesy btbuilder.  
+    [ticket:2490]
+
 - engine
   - [bug] Fixed memory leak in C version of
     result proxy whereby DBAPIs which don't deliver
index bf234fe5cc16df1055315ac5d870ca0237383f28..0f368f3f7c9ff1b109d4e5c523996174cb3c3cb9 100644 (file)
@@ -562,6 +562,10 @@ class SQLCompiler(engine.Compiled):
         text += (cs._limit is not None or cs._offset is not None) and \
                         self.limit_clause(cs) or ""
 
+        if self.ctes and \
+            compound_index==1 and not entry:
+            text = self._render_cte_clause() + text
+
         self.stack.pop(-1)
         if asfrom and parens:
             return "(" + text + ")"
@@ -958,12 +962,13 @@ class SQLCompiler(engine.Compiled):
 
         if self.ctes and \
             compound_index==1 and not entry:
-            cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
-            cte_text += ", \n".join(
-                [txt for txt in self.ctes.values()]
-            )
-            cte_text += "\n "
-            text = cte_text + text
+            text  = self._render_cte_clause() + text
+            #cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
+            #cte_text += ", \n".join(
+            #    [txt for txt in self.ctes.values()]
+            #)
+            #cte_text += "\n "
+            #text = cte_text + text
 
         self.stack.pop(-1)
 
@@ -972,6 +977,14 @@ class SQLCompiler(engine.Compiled):
         else:
             return text
 
+    def _render_cte_clause(self):
+        cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
+        cte_text += ", \n".join(
+            [txt for txt in self.ctes.values()]
+        )
+        cte_text += "\n "
+        return cte_text
+
     def get_cte_preamble(self, recursive):
         if recursive:
             return "WITH RECURSIVE"
index feb7405db5ddb7c826d4fdc5f06564e7b4d69ecf..ca041ea9c2991f98ec965b402aae249146450375 100644 (file)
@@ -315,7 +315,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             , dialect=default.DefaultDialect()
         )
 
-        # using alternate keys.  
+        # using alternate keys.
         # this will change with #2397
         a, b, c = Column('a', Integer, key='b'), \
                     Column('b', Integer), \
@@ -348,7 +348,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         eq_(s.positiontup, ['a', 'b', 'c'])
 
     def test_nested_label_targeting(self):
-        """test nested anonymous label generation.  
+        """test nested anonymous label generation.
 
         """
         s1 = table1.select()
@@ -1207,7 +1207,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT mytable.myid, mytable.name, mytable.description "
             "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE",
             dialect=postgresql.dialect())
-        
+
         self.assert_compile(
             table1.select(table1.c.myid==7, for_update="read_nowait"),
             "SELECT mytable.myid, mytable.name, mytable.description "
@@ -2450,6 +2450,46 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
                 dialect=mssql.dialect()
             )
 
+    def test_cte_union(self):
+        orders = table('orders', 
+            column('region'),
+            column('amount'),
+        )
+
+        regional_sales = select([
+                            orders.c.region,
+                            orders.c.amount
+                        ]).cte("regional_sales")
+
+        s = select([regional_sales.c.region]).\
+                where(
+                    regional_sales.c.amount > 500
+                )
+
+        self.assert_compile(s, 
+            "WITH regional_sales AS "
+            "(SELECT orders.region AS region, "
+            "orders.amount AS amount FROM orders) "
+            "SELECT regional_sales.region "
+            "FROM regional_sales WHERE "
+            "regional_sales.amount > :amount_1")
+
+        s = s.union_all(
+            select([regional_sales.c.region]).\
+                where(
+                    regional_sales.c.amount < 300
+                )
+        )
+        self.assert_compile(s, 
+            "WITH regional_sales AS "
+            "(SELECT orders.region AS region, "
+            "orders.amount AS amount FROM orders) "
+            "SELECT regional_sales.region FROM regional_sales "
+            "WHERE regional_sales.amount > :amount_1 "
+            "UNION ALL SELECT regional_sales.region "
+            "FROM regional_sales WHERE "
+            "regional_sales.amount < :amount_2")
+
     def test_date_between(self):
         import datetime
         table = Table('dt', metadata,