]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Check existing CTE for an alias name when rendering FROM clause
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 1 Mar 2018 15:45:39 +0000 (10:45 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 1 Mar 2018 15:54:35 +0000 (10:54 -0500)
Fixed bug in CTE rendering where a :class:`.CTE` that was also turned into
an :class:`.Alias` would not render its "ctename AS aliasname" clause
appropriately if there were more than one reference to the CTE in a FROM
clause.

Change-Id: If8cff27a2f4faa5eceb59aa86398db6edb3b9e72
Fixes: #4204
doc/build/changelog/unreleased_12/4204.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
test/sql/test_cte.py

diff --git a/doc/build/changelog/unreleased_12/4204.rst b/doc/build/changelog/unreleased_12/4204.rst
new file mode 100644 (file)
index 0000000..424a025
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 4204
+
+    Fixed bug in CTE rendering where a :class:`.CTE` that was also turned into
+    an :class:`.Alias` would not render its "ctename AS aliasname" clause
+    appropriately if there were more than one reference to the CTE in a FROM
+    clause.
index be41e80c5bc2ab5045a408667c30b12c057037be..438484f338bbca3a8116b3b9fd1680ec16bcdb6a 100644 (file)
@@ -1355,12 +1355,13 @@ class SQLCompiler(Compiled):
         else:
             cte_name = cte.name
 
+        is_new_cte = True
         if cte_name in self.ctes_by_name:
             existing_cte = self.ctes_by_name[cte_name]
             # we've generated a same-named CTE that we are enclosed in,
             # or this is the same CTE.  just return the name.
             if cte in existing_cte._restates or cte is existing_cte:
-                return self.preparer.format_alias(cte, cte_name)
+                is_new_cte = False
             elif existing_cte in cte._restates:
                 # we've generated a same-named CTE that is
                 # enclosed in us - we take precedence, so
@@ -1372,67 +1373,72 @@ class SQLCompiler(Compiled):
                     "the same name: %r" %
                     cte_name)
 
-        self.ctes_by_name[cte_name] = cte
-
-        # look for embedded DML ctes and propagate autocommit
-        if 'autocommit' in cte.element._execution_options and \
-                'autocommit' not in self.execution_options:
-            self.execution_options = self.execution_options.union(
-                {"autocommit": cte.element._execution_options['autocommit']})
-
-        if cte._cte_alias is not None:
-            orig_cte = cte._cte_alias
-            if orig_cte not in self.ctes:
-                self.visit_cte(orig_cte, **kwargs)
-            cte_alias_name = cte._cte_alias.name
-            if isinstance(cte_alias_name, elements._truncated_label):
-                cte_alias_name = self._truncated_identifier(
-                    "alias", cte_alias_name)
-        else:
-            orig_cte = cte
-            cte_alias_name = None
-        if not cte_alias_name and cte not in self.ctes:
-            if cte.recursive:
-                self.ctes_recursive = True
-            text = self.preparer.format_alias(cte, cte_name)
-            if cte.recursive:
-                if isinstance(cte.original, selectable.Select):
-                    col_source = cte.original
-                elif isinstance(cte.original, selectable.CompoundSelect):
-                    col_source = cte.original.selects[0]
-                else:
-                    assert False
-                recur_cols = [c for c in
-                              util.unique_list(col_source.inner_columns)
-                              if c is not None]
+        if asfrom or is_new_cte:
+            if cte._cte_alias is not None:
+                pre_alias_cte = cte._cte_alias
+                cte_pre_alias_name = cte._cte_alias.name
+                if isinstance(cte_pre_alias_name, elements._truncated_label):
+                    cte_pre_alias_name = self._truncated_identifier(
+                        "alias", cte_pre_alias_name)
+            else:
+                pre_alias_cte = cte
+                cte_pre_alias_name = None
+
+        if is_new_cte:
+            self.ctes_by_name[cte_name] = cte
+
+            # look for embedded DML ctes and propagate autocommit
+            if 'autocommit' in cte.element._execution_options and \
+                    'autocommit' not in self.execution_options:
+                self.execution_options = self.execution_options.union(
+                    {"autocommit":
+                     cte.element._execution_options['autocommit']})
+
+            if pre_alias_cte not in self.ctes:
+                self.visit_cte(pre_alias_cte, **kwargs)
+
+            if not cte_pre_alias_name and cte not in self.ctes:
+                if cte.recursive:
+                    self.ctes_recursive = True
+                text = self.preparer.format_alias(cte, cte_name)
+                if cte.recursive:
+                    if isinstance(cte.original, selectable.Select):
+                        col_source = cte.original
+                    elif isinstance(cte.original, selectable.CompoundSelect):
+                        col_source = cte.original.selects[0]
+                    else:
+                        assert False
+                    recur_cols = [c for c in
+                                  util.unique_list(col_source.inner_columns)
+                                  if c is not None]
 
-                text += "(%s)" % (", ".join(
-                    self.preparer.format_column(ident)
-                    for ident in recur_cols))
+                    text += "(%s)" % (", ".join(
+                        self.preparer.format_column(ident)
+                        for ident in recur_cols))
 
-            if self.positional:
-                kwargs['positional_names'] = self.cte_positional[cte] = []
+                if self.positional:
+                    kwargs['positional_names'] = self.cte_positional[cte] = []
 
-            text += " AS \n" + \
-                cte.original._compiler_dispatch(
-                    self, asfrom=True, **kwargs
-                )
+                text += " AS \n" + \
+                    cte.original._compiler_dispatch(
+                        self, asfrom=True, **kwargs
+                    )
 
-            if cte._suffixes:
-                text += " " + self._generate_prefixes(
-                    cte, cte._suffixes, **kwargs)
+                if cte._suffixes:
+                    text += " " + self._generate_prefixes(
+                        cte, cte._suffixes, **kwargs)
 
-            self.ctes[cte] = text
+                self.ctes[cte] = text
 
         if asfrom:
-            if cte_alias_name:
-                text = self.preparer.format_alias(cte, cte_alias_name)
+            if cte_pre_alias_name:
+                text = self.preparer.format_alias(cte, cte_pre_alias_name)
                 if self.preparer._requires_quotes(cte_name):
                     cte_name = self.preparer.quote(cte_name)
                 text += self.get_render_as_alias_suffix(cte_name)
+                return text
             else:
                 return self.preparer.format_alias(cte, cte_name)
-            return text
 
     def visit_alias(self, alias, asfrom=False, ashint=False,
                     iscrud=False,
index aadd470e8dfbff316d81bc04dd0704ef9312dfd0..af9c8ceb6fba3bccaffd86c12c7f122b110a4f70 100644 (file)
@@ -134,6 +134,21 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
                             "SELECT cte.x FROM cte"
                             )
 
+    def test_recursive_union_alias_one(self):
+        s1 = select([literal(0).label("x")])
+        cte = s1.cte(name="cte", recursive=True)
+        cte = cte.union_all(
+            select([cte.c.x + 1]).where(cte.c.x < 10)
+        ).alias("cr1")
+        s2 = select([cte])
+        self.assert_compile(s2,
+                            "WITH RECURSIVE cte(x) AS "
+                            "(SELECT :param_1 AS x UNION ALL "
+                            "SELECT cte.x + :x_1 AS anon_1 "
+                            "FROM cte WHERE cte.x < :x_2) "
+                            "SELECT cr1.x FROM cte AS cr1"
+                            )
+
     def test_recursive_union_no_alias_two(self):
         """
 
@@ -163,6 +178,26 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
                             "SELECT sum(t.n) AS sum_1 FROM t"
                             )
 
+    def test_recursive_union_alias_two(self):
+        """
+
+        """
+
+        # I know, this is the PG VALUES keyword,
+        # we're cheating here.  also yes we need the SELECT,
+        # sorry PG.
+        t = select([func.values(1).label("n")]).cte("t", recursive=True)
+        t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100)).alias('ta')
+        s = select([func.sum(t.c.n)])
+        self.assert_compile(s,
+                            "WITH RECURSIVE t(n) AS "
+                            "(SELECT values(:values_1) AS n "
+                            "UNION ALL SELECT t.n + :n_1 AS anon_1 "
+                            "FROM t "
+                            "WHERE t.n < :n_2) "
+                            "SELECT sum(ta.n) AS sum_1 FROM t AS ta"
+                            )
+
     def test_recursive_union_no_alias_three(self):
         # like test one, but let's refer to the CTE
         # in a sibling CTE.
@@ -187,6 +222,30 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
                             "SELECT cte.x, bar.x FROM cte, bar"
                             )
 
+    def test_recursive_union_alias_three(self):
+        # like test one, but let's refer to the CTE
+        # in a sibling CTE.
+
+        s1 = select([literal(0).label("x")])
+        cte = s1.cte(name="cte", recursive=True)
+
+        # can't do it here...
+        # bar = select([cte]).cte('bar')
+        cte = cte.union_all(
+            select([cte.c.x + 1]).where(cte.c.x < 10)
+        ).alias("cs1")
+        bar = select([cte]).cte('bar').alias("cs2")
+
+        s2 = select([cte, bar])
+        self.assert_compile(s2,
+                            "WITH RECURSIVE cte(x) AS "
+                            "(SELECT :param_1 AS x UNION ALL "
+                            "SELECT cte.x + :x_1 AS anon_1 "
+                            "FROM cte WHERE cte.x < :x_2), "
+                            "bar AS (SELECT cs1.x AS x FROM cte AS cs1) "
+                            "SELECT cs1.x, cs2.x FROM cte AS cs1, bar AS cs2"
+                            )
+
     def test_recursive_union_no_alias_four(self):
         # like test one and three, but let's refer
         # previous version of "cte".  here we test
@@ -234,6 +293,53 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             "FROM cte WHERE cte.x < :x_2) "
             "SELECT bar.x, cte.x FROM bar, cte")
 
+    def test_recursive_union_alias_four(self):
+        # like test one and three, but let's refer
+        # previous version of "cte".  here we test
+        # how the compiler resolves multiple instances
+        # of "cte".
+
+        s1 = select([literal(0).label("x")])
+        cte = s1.cte(name="cte", recursive=True)
+
+        bar = select([cte]).cte('bar').alias("cs1")
+        cte = cte.union_all(
+            select([cte.c.x + 1]).where(cte.c.x < 10)
+        ).alias("cs2")
+
+        # outer cte rendered first, then bar, which
+        # includes "inner" cte
+        s2 = select([cte, bar])
+        self.assert_compile(s2,
+                            "WITH RECURSIVE cte(x) AS "
+                            "(SELECT :param_1 AS x UNION ALL "
+                            "SELECT cte.x + :x_1 AS anon_1 "
+                            "FROM cte WHERE cte.x < :x_2), "
+                            "bar AS (SELECT cte.x AS x FROM cte) "
+                            "SELECT cs2.x, cs1.x FROM cte AS cs2, bar AS cs1"
+                            )
+
+        # bar rendered, only includes "inner" cte,
+        # "outer" cte isn't present
+        s2 = select([bar])
+        self.assert_compile(s2,
+                            "WITH RECURSIVE cte(x) AS "
+                            "(SELECT :param_1 AS x), "
+                            "bar AS (SELECT cte.x AS x FROM cte) "
+                            "SELECT cs1.x FROM bar AS cs1"
+                            )
+
+        # bar rendered, but then the "outer"
+        # cte is rendered.
+        s2 = select([bar, cte])
+        self.assert_compile(
+            s2, "WITH RECURSIVE bar AS (SELECT cte.x AS x FROM cte), "
+            "cte(x) AS "
+            "(SELECT :param_1 AS x UNION ALL "
+            "SELECT cte.x + :x_1 AS anon_1 "
+            "FROM cte WHERE cte.x < :x_2) "
+            "SELECT cs1.x, cs2.x FROM bar AS cs1, cte AS cs2")
+
     def test_conflicting_names(self):
         """test a flat out name conflict."""
 
@@ -290,6 +396,46 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
                             "FROM regional_sales WHERE "
                             "regional_sales.amount < :amount_2")
 
+    def test_union_cte_aliases(self):
+        orders = table('orders',
+                       column('region'),
+                       column('amount'),
+                       )
+
+        regional_sales = select([
+            orders.c.region,
+            orders.c.amount
+        ]).cte("regional_sales").alias("rs")
+
+        s = select(
+            [regional_sales.c.region]).where(
+            regional_sales.c.amount > 500
+        )
+
+        self.assert_compile(s,
+                            "WITH regional_sales AS "
+                            "(SELECT orders.region AS region, "
+                            "orders.amount AS amount FROM orders) "
+                            "SELECT rs.region "
+                            "FROM regional_sales AS rs WHERE "
+                            "rs.amount > :amount_1")
+
+        s = s.union_all(
+            select([regional_sales.c.region]).
+            where(
+                regional_sales.c.amount < 300
+            )
+        )
+        self.assert_compile(s,
+                            "WITH regional_sales AS "
+                            "(SELECT orders.region AS region, "
+                            "orders.amount AS amount FROM orders) "
+                            "SELECT rs.region FROM regional_sales AS rs "
+                            "WHERE rs.amount > :amount_1 "
+                            "UNION ALL SELECT rs.region "
+                            "FROM regional_sales AS rs WHERE "
+                            "rs.amount < :amount_2")
+
     def test_reserved_quote(self):
         orders = table('orders',
                        column('order'),
@@ -319,6 +465,60 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             '(SELECT "CTE".id AS id FROM "CTE") AS anon_2'
         )
 
+    def test_multi_subq_alias(self):
+        cte = select([literal(1).label("id")]).cte(name='cte1').alias("aa")
+
+        s1 = select([cte.c.id]).alias()
+        s2 = select([cte.c.id]).alias()
+
+        s = select([s1, s2])
+        self.assert_compile(
+            s,
+            "WITH cte1 AS (SELECT :param_1 AS id) "
+            "SELECT anon_1.id, anon_2.id FROM "
+            "(SELECT aa.id AS id FROM cte1 AS aa) AS anon_1, "
+            "(SELECT aa.id AS id FROM cte1 AS aa) AS anon_2"
+        )
+
+    def test_cte_refers_to_aliased_cte_twice(self):
+        # test issue #4204
+        a = table('a', column('id'))
+        b = table('b', column('id'), column('fid'))
+        c = table('c', column('id'), column('fid'))
+
+        cte1 = (
+            select([a.c.id])
+            .cte(name='cte1')
+        )
+
+        aa = cte1.alias('aa')
+
+        cte2 = (
+            select([b.c.id])
+            .select_from(b.join(aa, b.c.fid == aa.c.id))
+            .cte(name='cte2')
+        )
+
+        cte3 = (
+            select([c.c.id])
+            .select_from(c.join(aa, c.c.fid == aa.c.id))
+            .cte(name='cte3')
+        )
+
+        stmt = (
+            select([cte3.c.id, cte2.c.id])
+            .select_from(cte2.join(cte3, cte2.c.id == cte3.c.id))
+        )
+        self.assert_compile(
+            stmt,
+            "WITH cte1 AS (SELECT a.id AS id FROM a), "
+            "cte2 AS (SELECT b.id AS id FROM b "
+            "JOIN cte1 AS aa ON b.fid = aa.id), "
+            "cte3 AS (SELECT c.id AS id FROM c "
+            "JOIN cte1 AS aa ON c.fid = aa.id) "
+            "SELECT cte3.id, cte2.id FROM cte2 JOIN cte3 ON cte2.id = cte3.id"
+        )
+
     def test_named_alias_no_quote(self):
         cte = select([literal(1).label("id")]).cte(name='CTE')