From 339f3abdeb63bba68492da2aac903a98c32ca421 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 13 Jun 2012 18:19:26 -0400 Subject: [PATCH] - [bug] Repaired common table expression rendering to function correctly when the SELECT statement contains UNION or other compound expressions, courtesy btbuilder. [ticket:2490] --- CHANGES | 6 +++++ lib/sqlalchemy/sql/compiler.py | 25 +++++++++++++----- test/sql/test_compiler.py | 46 +++++++++++++++++++++++++++++++--- 3 files changed, 68 insertions(+), 9 deletions(-) diff --git a/CHANGES b/CHANGES index 0fb6a9f856..e1148ae189 100644 --- a/CHANGES +++ b/CHANGES @@ -39,6 +39,12 @@ CHANGES this breakage doesn't occur again. [ticket:2499] + - [bug] Repaired common table expression + rendering to function correctly when the + SELECT statement contains UNION or other + compound expressions, courtesy btbuilder. + [ticket:2490] + - engine - [bug] Fixed memory leak in C version of result proxy whereby DBAPIs which don't deliver diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index bf234fe5cc..0f368f3f7c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -562,6 +562,10 @@ class SQLCompiler(engine.Compiled): text += (cs._limit is not None or cs._offset is not None) and \ self.limit_clause(cs) or "" + if self.ctes and \ + compound_index==1 and not entry: + text = self._render_cte_clause() + text + self.stack.pop(-1) if asfrom and parens: return "(" + text + ")" @@ -958,12 +962,13 @@ class SQLCompiler(engine.Compiled): if self.ctes and \ compound_index==1 and not entry: - cte_text = self.get_cte_preamble(self.ctes_recursive) + " " - cte_text += ", \n".join( - [txt for txt in self.ctes.values()] - ) - cte_text += "\n " - text = cte_text + text + text = self._render_cte_clause() + text + #cte_text = self.get_cte_preamble(self.ctes_recursive) + " " + #cte_text += ", \n".join( + # [txt for txt in self.ctes.values()] + #) + #cte_text += "\n " + #text = cte_text + text self.stack.pop(-1) @@ -972,6 +977,14 @@ class SQLCompiler(engine.Compiled): else: return text + def _render_cte_clause(self): + cte_text = self.get_cte_preamble(self.ctes_recursive) + " " + cte_text += ", \n".join( + [txt for txt in self.ctes.values()] + ) + cte_text += "\n " + return cte_text + def get_cte_preamble(self, recursive): if recursive: return "WITH RECURSIVE" diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index feb7405db5..ca041ea9c2 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -315,7 +315,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): , dialect=default.DefaultDialect() ) - # using alternate keys. + # using alternate keys. # this will change with #2397 a, b, c = Column('a', Integer, key='b'), \ Column('b', Integer), \ @@ -348,7 +348,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): eq_(s.positiontup, ['a', 'b', 'c']) def test_nested_label_targeting(self): - """test nested anonymous label generation. + """test nested anonymous label generation. """ s1 = table1.select() @@ -1207,7 +1207,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT mytable.myid, mytable.name, mytable.description " "FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE", dialect=postgresql.dialect()) - + self.assert_compile( table1.select(table1.c.myid==7, for_update="read_nowait"), "SELECT mytable.myid, mytable.name, mytable.description " @@ -2450,6 +2450,46 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): dialect=mssql.dialect() ) + def test_cte_union(self): + orders = table('orders', + column('region'), + column('amount'), + ) + + regional_sales = select([ + orders.c.region, + orders.c.amount + ]).cte("regional_sales") + + s = select([regional_sales.c.region]).\ + where( + regional_sales.c.amount > 500 + ) + + self.assert_compile(s, + "WITH regional_sales AS " + "(SELECT orders.region AS region, " + "orders.amount AS amount FROM orders) " + "SELECT regional_sales.region " + "FROM regional_sales WHERE " + "regional_sales.amount > :amount_1") + + s = s.union_all( + select([regional_sales.c.region]).\ + where( + regional_sales.c.amount < 300 + ) + ) + self.assert_compile(s, + "WITH regional_sales AS " + "(SELECT orders.region AS region, " + "orders.amount AS amount FROM orders) " + "SELECT regional_sales.region FROM regional_sales " + "WHERE regional_sales.amount > :amount_1 " + "UNION ALL SELECT regional_sales.region " + "FROM regional_sales WHERE " + "regional_sales.amount < :amount_2") + def test_date_between(self): import datetime table = Table('dt', metadata, -- 2.47.2