text += self._row_limit_clause(cs, **kwargs)
if self.ctes:
- # Nesting CTEs from deeper select
- nesting_level = (len(self.stack) + 1) if not toplevel else None
- text = self._render_cte_clause(nesting_level=nesting_level) + text
+ nesting_level = len(self.stack) if not toplevel else None
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level, include_following_stack=True
+ )
+ + text
+ )
self.stack.pop(-1)
return text
else:
return self.bindtemplate % {"name": name}
+ def get_name(self, name):
+ if isinstance(name, elements._truncated_label):
+ return self._truncated_identifier("alias", name)
+ else:
+ return name
+
def visit_cte(
self,
cte,
cte_level = len(self.stack) if cte.nesting else 1
kwargs["visiting_cte"] = cte
- if isinstance(cte.name, elements._truncated_label):
- cte_name = self._truncated_identifier("alias", cte.name)
- else:
- cte_name = cte.name
+
+ cte_name = self.get_name(cte.name)
is_new_cte = True
embedded_in_current_named_cte = False
if toplevel and not self.compile_state:
self.compile_state = compile_state
+ is_embedded_select = compound_index is not None or insert_into
+
# translate step for Oracle, SQL Server which often need to
# restructure the SELECT to allow for LIMIT/OFFSET and possibly
# other conditions
if per_dialect:
text += " " + self.get_statement_hint_text(per_dialect)
- # In compound query, CTEs are shared at the compound level
- if self.ctes and compound_index is None and not insert_into:
- nesting_level = len(self.stack) if not toplevel else None
- text = self._render_cte_clause(nesting_level=nesting_level) + text
+ if self.ctes:
+ # In compound query, CTEs are shared at the compound level
+ if not is_embedded_select:
+ nesting_level = len(self.stack) if not toplevel else None
+ text = (
+ self._render_cte_clause(nesting_level=nesting_level) + text
+ )
if select_stmt._suffixes:
text += " " + self._generate_prefixes(
def _render_cte_clause(
self,
nesting_level=None,
+ include_following_stack=False,
):
+ """
+ include_following_stack
+ Also render the nesting CTEs on the next stack. Useful for
+ SQL structures like UNION or INSERT that can wrap SELECT
+ statements containing nesting CTEs.
+ """
if not self.ctes:
return ""
ctes = {}
for cte in list(self.ctes.keys()):
cte_level = self.level_by_ctes[cte]
- if not (cte.nesting and cte_level == nesting_level):
+ is_rendered_level = cte_level == nesting_level or (
+ include_following_stack and cte_level == nesting_level + 1
+ )
+ if not (cte.nesting and is_rendered_level):
continue
ctes[cte] = self.ctes[cte]
del self.ctes[cte]
del self.level_by_ctes[cte]
+ cte_name = self.get_name(cte.name)
+ del self.ctes_by_name[(cte_level, cte_name)]
else:
ctes = self.ctes
)
if self.ctes and self.dialect.cte_follows_insert:
- nesting_level = (len(self.stack) + 1) if not toplevel else None
+ nesting_level = len(self.stack) if not toplevel else None
text += " %s%s" % (
- self._render_cte_clause(nesting_level=nesting_level),
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ include_following_stack=True,
+ ),
select_text,
)
else:
text += " " + returning_clause
if self.ctes and not self.dialect.cte_follows_insert:
- nesting_level = (len(self.stack) + 1) if not toplevel else None
- text = self._render_cte_clause(nesting_level=nesting_level) + text
+ nesting_level = len(self.stack) if not toplevel else None
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level, include_following_stack=True
+ )
+ + text
+ )
self.stack.pop(-1)
"table_1.price) SELECT delete_cte.id, delete_cte.price "
"FROM delete_cte",
)
+
+ def test_compound_select_with_nesting_cte_in_custom_order(self):
+ select_1_cte = select(literal(1).label("inner_cte")).cte(
+ "nesting_1", nesting=True
+ )
+ select_2_cte = select(literal(2).label("inner_cte")).cte(
+ "nesting_2", nesting=True
+ )
+
+ nesting_cte = (
+ select(select_1_cte)
+ .union(select(select_2_cte))
+ # Generate "select_2_cte" first
+ .add_cte(select_2_cte)
+ .subquery()
+ )
+
+ stmt = select(
+ select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte")
+ )
+
+ self.assert_compile(
+ stmt,
+ "WITH cte AS ("
+ "SELECT anon_1.inner_cte AS outer_cte FROM ("
+ "WITH nesting_2 AS (SELECT %(param_1)s AS inner_cte)"
+ ", nesting_1 AS (SELECT %(param_2)s AS inner_cte)"
+ " SELECT nesting_1.inner_cte AS inner_cte FROM nesting_1"
+ " UNION"
+ " SELECT nesting_2.inner_cte AS inner_cte FROM nesting_2"
+ ") AS anon_1"
+ ") SELECT cte.outer_cte FROM cte",
+ )
+
+ def test_recursive_cte_referenced_multiple_times_with_nesting_cte(self):
+ rec_root = select(literal(1).label("the_value")).cte(
+ "recursive_cte", recursive=True
+ )
+
+ # Allow to reference the recursive CTE more than once
+ rec_root_ref = rec_root.select().cte(
+ "allow_multiple_ref", nesting=True
+ )
+ should_continue = select(
+ exists(
+ select(rec_root_ref.c.the_value)
+ .where(rec_root_ref.c.the_value < 10)
+ .limit(1)
+ ).label("val")
+ ).cte("should_continue", nesting=True)
+
+ rec_part_1 = select(rec_root_ref.c.the_value * 2).where(
+ should_continue.c.val != True
+ )
+ rec_part_2 = select(rec_root_ref.c.the_value * 3).where(
+ should_continue.c.val != True
+ )
+
+ rec_part = rec_part_1.add_cte(rec_root_ref).union_all(rec_part_2)
+
+ rec_cte = rec_root.union_all(rec_part)
+
+ stmt = rec_cte.select()
+
+ self.assert_compile(
+ stmt,
+ "WITH RECURSIVE recursive_cte(the_value) AS ("
+ "SELECT %(param_1)s AS the_value UNION ALL ("
+ "WITH allow_multiple_ref AS ("
+ "SELECT recursive_cte.the_value AS the_value FROM recursive_cte)"
+ ", should_continue AS (SELECT EXISTS ("
+ "SELECT allow_multiple_ref.the_value FROM allow_multiple_ref"
+ " WHERE allow_multiple_ref.the_value < %(the_value_2)s"
+ " LIMIT %(param_2)s) AS val) "
+ "SELECT allow_multiple_ref.the_value * %(the_value_1)s AS anon_1"
+ " FROM allow_multiple_ref, should_continue WHERE should_continue.val != true"
+ " UNION ALL SELECT allow_multiple_ref.the_value * %(the_value_3)s"
+ " AS anon_2 FROM allow_multiple_ref, should_continue"
+ " WHERE should_continue.val != true))"
+ " SELECT recursive_cte.the_value FROM recursive_cte",
+ )