From: Mike Bayer Date: Thu, 1 Mar 2018 15:45:39 +0000 (-0500) Subject: Check existing CTE for an alias name when rendering FROM clause X-Git-Tag: rel_1_3_0b1~244^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5f60dc649cde2525f5eb1e7008a75304603b751c;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Check existing CTE for an alias name when rendering FROM clause Fixed bug in CTE rendering where a :class:`.CTE` that was also turned into an :class:`.Alias` would not render its "ctename AS aliasname" clause appropriately if there were more than one reference to the CTE in a FROM clause. Change-Id: If8cff27a2f4faa5eceb59aa86398db6edb3b9e72 Fixes: #4204 --- diff --git a/doc/build/changelog/unreleased_12/4204.rst b/doc/build/changelog/unreleased_12/4204.rst new file mode 100644 index 0000000000..424a025843 --- /dev/null +++ b/doc/build/changelog/unreleased_12/4204.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, sql + :tickets: 4204 + + Fixed bug in CTE rendering where a :class:`.CTE` that was also turned into + an :class:`.Alias` would not render its "ctename AS aliasname" clause + appropriately if there were more than one reference to the CTE in a FROM + clause. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index be41e80c5b..438484f338 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1355,12 +1355,13 @@ class SQLCompiler(Compiled): else: cte_name = cte.name + is_new_cte = True if cte_name in self.ctes_by_name: existing_cte = self.ctes_by_name[cte_name] # we've generated a same-named CTE that we are enclosed in, # or this is the same CTE. just return the name. if cte in existing_cte._restates or cte is existing_cte: - return self.preparer.format_alias(cte, cte_name) + is_new_cte = False elif existing_cte in cte._restates: # we've generated a same-named CTE that is # enclosed in us - we take precedence, so @@ -1372,67 +1373,72 @@ class SQLCompiler(Compiled): "the same name: %r" % cte_name) - self.ctes_by_name[cte_name] = cte - - # look for embedded DML ctes and propagate autocommit - if 'autocommit' in cte.element._execution_options and \ - 'autocommit' not in self.execution_options: - self.execution_options = self.execution_options.union( - {"autocommit": cte.element._execution_options['autocommit']}) - - if cte._cte_alias is not None: - orig_cte = cte._cte_alias - if orig_cte not in self.ctes: - self.visit_cte(orig_cte, **kwargs) - cte_alias_name = cte._cte_alias.name - if isinstance(cte_alias_name, elements._truncated_label): - cte_alias_name = self._truncated_identifier( - "alias", cte_alias_name) - else: - orig_cte = cte - cte_alias_name = None - if not cte_alias_name and cte not in self.ctes: - if cte.recursive: - self.ctes_recursive = True - text = self.preparer.format_alias(cte, cte_name) - if cte.recursive: - if isinstance(cte.original, selectable.Select): - col_source = cte.original - elif isinstance(cte.original, selectable.CompoundSelect): - col_source = cte.original.selects[0] - else: - assert False - recur_cols = [c for c in - util.unique_list(col_source.inner_columns) - if c is not None] + if asfrom or is_new_cte: + if cte._cte_alias is not None: + pre_alias_cte = cte._cte_alias + cte_pre_alias_name = cte._cte_alias.name + if isinstance(cte_pre_alias_name, elements._truncated_label): + cte_pre_alias_name = self._truncated_identifier( + "alias", cte_pre_alias_name) + else: + pre_alias_cte = cte + cte_pre_alias_name = None + + if is_new_cte: + self.ctes_by_name[cte_name] = cte + + # look for embedded DML ctes and propagate autocommit + if 'autocommit' in cte.element._execution_options and \ + 'autocommit' not in self.execution_options: + self.execution_options = self.execution_options.union( + {"autocommit": + cte.element._execution_options['autocommit']}) + + if pre_alias_cte not in self.ctes: + self.visit_cte(pre_alias_cte, **kwargs) + + if not cte_pre_alias_name and cte not in self.ctes: + if cte.recursive: + self.ctes_recursive = True + text = self.preparer.format_alias(cte, cte_name) + if cte.recursive: + if isinstance(cte.original, selectable.Select): + col_source = cte.original + elif isinstance(cte.original, selectable.CompoundSelect): + col_source = cte.original.selects[0] + else: + assert False + recur_cols = [c for c in + util.unique_list(col_source.inner_columns) + if c is not None] - text += "(%s)" % (", ".join( - self.preparer.format_column(ident) - for ident in recur_cols)) + text += "(%s)" % (", ".join( + self.preparer.format_column(ident) + for ident in recur_cols)) - if self.positional: - kwargs['positional_names'] = self.cte_positional[cte] = [] + if self.positional: + kwargs['positional_names'] = self.cte_positional[cte] = [] - text += " AS \n" + \ - cte.original._compiler_dispatch( - self, asfrom=True, **kwargs - ) + text += " AS \n" + \ + cte.original._compiler_dispatch( + self, asfrom=True, **kwargs + ) - if cte._suffixes: - text += " " + self._generate_prefixes( - cte, cte._suffixes, **kwargs) + if cte._suffixes: + text += " " + self._generate_prefixes( + cte, cte._suffixes, **kwargs) - self.ctes[cte] = text + self.ctes[cte] = text if asfrom: - if cte_alias_name: - text = self.preparer.format_alias(cte, cte_alias_name) + if cte_pre_alias_name: + text = self.preparer.format_alias(cte, cte_pre_alias_name) if self.preparer._requires_quotes(cte_name): cte_name = self.preparer.quote(cte_name) text += self.get_render_as_alias_suffix(cte_name) + return text else: return self.preparer.format_alias(cte, cte_name) - return text def visit_alias(self, alias, asfrom=False, ashint=False, iscrud=False, diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index aadd470e8d..af9c8ceb6f 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -134,6 +134,21 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT cte.x FROM cte" ) + def test_recursive_union_alias_one(self): + s1 = select([literal(0).label("x")]) + cte = s1.cte(name="cte", recursive=True) + cte = cte.union_all( + select([cte.c.x + 1]).where(cte.c.x < 10) + ).alias("cr1") + s2 = select([cte]) + self.assert_compile(s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x UNION ALL " + "SELECT cte.x + :x_1 AS anon_1 " + "FROM cte WHERE cte.x < :x_2) " + "SELECT cr1.x FROM cte AS cr1" + ) + def test_recursive_union_no_alias_two(self): """ @@ -163,6 +178,26 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT sum(t.n) AS sum_1 FROM t" ) + def test_recursive_union_alias_two(self): + """ + + """ + + # I know, this is the PG VALUES keyword, + # we're cheating here. also yes we need the SELECT, + # sorry PG. + t = select([func.values(1).label("n")]).cte("t", recursive=True) + t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100)).alias('ta') + s = select([func.sum(t.c.n)]) + self.assert_compile(s, + "WITH RECURSIVE t(n) AS " + "(SELECT values(:values_1) AS n " + "UNION ALL SELECT t.n + :n_1 AS anon_1 " + "FROM t " + "WHERE t.n < :n_2) " + "SELECT sum(ta.n) AS sum_1 FROM t AS ta" + ) + def test_recursive_union_no_alias_three(self): # like test one, but let's refer to the CTE # in a sibling CTE. @@ -187,6 +222,30 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT cte.x, bar.x FROM cte, bar" ) + def test_recursive_union_alias_three(self): + # like test one, but let's refer to the CTE + # in a sibling CTE. + + s1 = select([literal(0).label("x")]) + cte = s1.cte(name="cte", recursive=True) + + # can't do it here... + # bar = select([cte]).cte('bar') + cte = cte.union_all( + select([cte.c.x + 1]).where(cte.c.x < 10) + ).alias("cs1") + bar = select([cte]).cte('bar').alias("cs2") + + s2 = select([cte, bar]) + self.assert_compile(s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x UNION ALL " + "SELECT cte.x + :x_1 AS anon_1 " + "FROM cte WHERE cte.x < :x_2), " + "bar AS (SELECT cs1.x AS x FROM cte AS cs1) " + "SELECT cs1.x, cs2.x FROM cte AS cs1, bar AS cs2" + ) + def test_recursive_union_no_alias_four(self): # like test one and three, but let's refer # previous version of "cte". here we test @@ -234,6 +293,53 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "FROM cte WHERE cte.x < :x_2) " "SELECT bar.x, cte.x FROM bar, cte") + def test_recursive_union_alias_four(self): + # like test one and three, but let's refer + # previous version of "cte". here we test + # how the compiler resolves multiple instances + # of "cte". + + s1 = select([literal(0).label("x")]) + cte = s1.cte(name="cte", recursive=True) + + bar = select([cte]).cte('bar').alias("cs1") + cte = cte.union_all( + select([cte.c.x + 1]).where(cte.c.x < 10) + ).alias("cs2") + + # outer cte rendered first, then bar, which + # includes "inner" cte + s2 = select([cte, bar]) + self.assert_compile(s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x UNION ALL " + "SELECT cte.x + :x_1 AS anon_1 " + "FROM cte WHERE cte.x < :x_2), " + "bar AS (SELECT cte.x AS x FROM cte) " + "SELECT cs2.x, cs1.x FROM cte AS cs2, bar AS cs1" + ) + + # bar rendered, only includes "inner" cte, + # "outer" cte isn't present + s2 = select([bar]) + self.assert_compile(s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x), " + "bar AS (SELECT cte.x AS x FROM cte) " + "SELECT cs1.x FROM bar AS cs1" + ) + + # bar rendered, but then the "outer" + # cte is rendered. + s2 = select([bar, cte]) + self.assert_compile( + s2, "WITH RECURSIVE bar AS (SELECT cte.x AS x FROM cte), " + "cte(x) AS " + "(SELECT :param_1 AS x UNION ALL " + "SELECT cte.x + :x_1 AS anon_1 " + "FROM cte WHERE cte.x < :x_2) " + "SELECT cs1.x, cs2.x FROM bar AS cs1, cte AS cs2") + def test_conflicting_names(self): """test a flat out name conflict.""" @@ -290,6 +396,46 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "FROM regional_sales WHERE " "regional_sales.amount < :amount_2") + def test_union_cte_aliases(self): + orders = table('orders', + column('region'), + column('amount'), + ) + + regional_sales = select([ + orders.c.region, + orders.c.amount + ]).cte("regional_sales").alias("rs") + + s = select( + [regional_sales.c.region]).where( + regional_sales.c.amount > 500 + ) + + self.assert_compile(s, + "WITH regional_sales AS " + "(SELECT orders.region AS region, " + "orders.amount AS amount FROM orders) " + "SELECT rs.region " + "FROM regional_sales AS rs WHERE " + "rs.amount > :amount_1") + + s = s.union_all( + select([regional_sales.c.region]). + where( + regional_sales.c.amount < 300 + ) + ) + self.assert_compile(s, + "WITH regional_sales AS " + "(SELECT orders.region AS region, " + "orders.amount AS amount FROM orders) " + "SELECT rs.region FROM regional_sales AS rs " + "WHERE rs.amount > :amount_1 " + "UNION ALL SELECT rs.region " + "FROM regional_sales AS rs WHERE " + "rs.amount < :amount_2") + def test_reserved_quote(self): orders = table('orders', column('order'), @@ -319,6 +465,60 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): '(SELECT "CTE".id AS id FROM "CTE") AS anon_2' ) + def test_multi_subq_alias(self): + cte = select([literal(1).label("id")]).cte(name='cte1').alias("aa") + + s1 = select([cte.c.id]).alias() + s2 = select([cte.c.id]).alias() + + s = select([s1, s2]) + self.assert_compile( + s, + "WITH cte1 AS (SELECT :param_1 AS id) " + "SELECT anon_1.id, anon_2.id FROM " + "(SELECT aa.id AS id FROM cte1 AS aa) AS anon_1, " + "(SELECT aa.id AS id FROM cte1 AS aa) AS anon_2" + ) + + def test_cte_refers_to_aliased_cte_twice(self): + # test issue #4204 + a = table('a', column('id')) + b = table('b', column('id'), column('fid')) + c = table('c', column('id'), column('fid')) + + cte1 = ( + select([a.c.id]) + .cte(name='cte1') + ) + + aa = cte1.alias('aa') + + cte2 = ( + select([b.c.id]) + .select_from(b.join(aa, b.c.fid == aa.c.id)) + .cte(name='cte2') + ) + + cte3 = ( + select([c.c.id]) + .select_from(c.join(aa, c.c.fid == aa.c.id)) + .cte(name='cte3') + ) + + stmt = ( + select([cte3.c.id, cte2.c.id]) + .select_from(cte2.join(cte3, cte2.c.id == cte3.c.id)) + ) + self.assert_compile( + stmt, + "WITH cte1 AS (SELECT a.id AS id FROM a), " + "cte2 AS (SELECT b.id AS id FROM b " + "JOIN cte1 AS aa ON b.fid = aa.id), " + "cte3 AS (SELECT c.id AS id FROM c " + "JOIN cte1 AS aa ON c.fid = aa.id) " + "SELECT cte3.id, cte2.id FROM cte2 JOIN cte3 ON cte2.id = cte3.id" + ) + def test_named_alias_no_quote(self): cte = select([literal(1).label("id")]).cte(name='CTE')