]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed bug in common table expression system where if the CTE were
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 31 Jul 2013 22:42:58 +0000 (18:42 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 31 Jul 2013 22:48:27 +0000 (18:48 -0400)
used only as an ``alias()`` construct, it would not render using the
WITH keyword.
[ticket:2783]

doc/build/changelog/changelog_07.rst
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/sql/test_cte.py

index b3b37861e4264ab4bbb5a502f15dcab4e215b6c0..59276c73a7bf5abf5bab5a15a8c89540b6a62068 100644 (file)
@@ -6,6 +6,14 @@
 .. changelog::
     :version: 0.7.11
 
+    .. change::
+        :tags: sql, bug, cte
+        :tickets: 2783
+
+        Fixed bug in common table expression system where if the CTE were
+        used only as an ``alias()`` construct, it would not render using the
+        WITH keyword.  Also in 0.8.3, 0.7.11.
+
     .. change::
         :tags: bug, sql
         :tickets: 2784
index 9dc56d1f09d57d6a77d6320ff98a3f660199ac29..2e6301f159ffc4dfd72c7a82851bcea90543ff26 100644 (file)
@@ -811,12 +811,17 @@ class SQLCompiler(engine.Compiled):
 
         self.ctes_by_name[cte_name] = cte
 
-        if cte.cte_alias:
-            if isinstance(cte.cte_alias, sql._truncated_label):
-                cte_alias = self._truncated_identifier("alias", cte.cte_alias)
-            else:
-                cte_alias = cte.cte_alias
-        if not cte.cte_alias and cte not in self.ctes:
+        if cte._cte_alias is not None:
+            orig_cte = cte._cte_alias
+            if orig_cte not in self.ctes:
+                self.visit_cte(orig_cte)
+            cte_alias_name = cte._cte_alias.name
+            if isinstance(cte_alias_name, sql._truncated_label):
+                cte_alias_name = self._truncated_identifier("alias", cte_alias_name)
+        else:
+            orig_cte = cte
+            cte_alias_name = None
+        if not cte_alias_name and cte not in self.ctes:
             if cte.recursive:
                 self.ctes_recursive = True
             text = self.preparer.format_alias(cte, cte_name)
@@ -839,9 +844,10 @@ class SQLCompiler(engine.Compiled):
                                 self, asfrom=True, **kwargs
                             )
             self.ctes[cte] = text
+
         if asfrom:
-            if cte.cte_alias:
-                text = self.preparer.format_alias(cte, cte_alias)
+            if cte_alias_name:
+                text = self.preparer.format_alias(cte, cte_alias_name)
                 text += " AS " + cte_name
             else:
                 return self.preparer.format_alias(cte, cte_name)
index c90a3dcb0212f75c92f349e84d891a53a7d8c74e..2868af2212f98c6aaaaf0b8a83c7e720d7824db5 100644 (file)
@@ -3764,10 +3764,10 @@ class CTE(Alias):
     def __init__(self, selectable,
                         name=None,
                         recursive=False,
-                        cte_alias=False,
+                        _cte_alias=None,
                         _restates=frozenset()):
         self.recursive = recursive
-        self.cte_alias = cte_alias
+        self._cte_alias = _cte_alias
         self._restates = _restates
         super(CTE, self).__init__(selectable, name=name)
 
@@ -3776,8 +3776,8 @@ class CTE(Alias):
             self.original,
             name=name,
             recursive=self.recursive,
-            cte_alias = self.name
-        )
+            _cte_alias=self,
+          )
 
     def union(self, other):
         return CTE(
index 59b347ccd203ebdaaec19b554c9cdc0a20510da5..71663ca789be4bfa6c3c3f57fcf1f1e029d2907b 100644 (file)
@@ -350,3 +350,32 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             checkpositional=('x', 'y'),
             dialect=dialect
         )
+
+    def test_all_aliases(self):
+        orders = table('order', column('order'))
+        s = select([orders.c.order]).cte("regional_sales")
+
+        r1 = s.alias()
+        r2 = s.alias()
+
+        s2 = select([r1, r2]).where(r1.c.order > r2.c.order)
+
+        self.assert_compile(
+            s2,
+            'WITH regional_sales AS (SELECT "order"."order" '
+            'AS "order" FROM "order") '
+            'SELECT anon_1."order", anon_2."order" '
+            'FROM regional_sales AS anon_1, '
+            'regional_sales AS anon_2 WHERE anon_1."order" > anon_2."order"'
+        )
+
+        s3 = select([orders]).select_from(orders.join(r1, r1.c.order == orders.c.order))
+
+        self.assert_compile(
+            s3,
+            'WITH regional_sales AS '
+            '(SELECT "order"."order" AS "order" '
+            'FROM "order")'
+            ' SELECT "order"."order" '
+            'FROM "order" JOIN regional_sales AS anon_1 ON anon_1."order" = "order"."order"'
+        )