--- /dev/null
+.. change::
+ :tags: bug, sql
+ :tickets: 4123
+
+ Repaired issue in new :paramref:`_sql.HasCTE.cte.nesting` parameter
+ introduced with :ticket:`4123` where a recursive :class:`_sql.CTE` using
+ :paramref:`_sql.HasCTE.cte.recursive` in typical conjunction with UNION
+ would not compile correctly. Additionally makes some adjustments so that
+ the :class:`_sql.CTE` construct creates a correct cache key.
+ Pull request courtesy Eric Masseran.
"""
# collect CTEs to tack on top of a SELECT
+ # To store the query to print - Dict[cte, text_query]
self.ctes = util.OrderedDict()
- # Detect same CTE references
- self.ctes_by_name = {}
- self.level_by_ctes = {}
+
+ # Detect same CTE references - Dict[(level, name), cte]
+ # Level is required for supporting nesting
+ self.ctes_by_level_name = {}
+
+ # To retrieve key/level in ctes_by_level_name -
+ # Dict[cte_reference, (level, cte_name)]
+ self.level_name_by_cte = {}
+
self.ctes_recursive = False
if self.positional:
self.cte_positional = {}
):
self._init_cte_state()
- cte_level = len(self.stack) if cte.nesting else 1
-
kwargs["visiting_cte"] = cte
cte_name = cte.name
is_new_cte = True
embedded_in_current_named_cte = False
- if cte in self.level_by_ctes:
- cte_level = self.level_by_ctes[cte]
+ _reference_cte = cte._get_reference_cte()
+
+ if _reference_cte in self.level_name_by_cte:
+ cte_level, _ = self.level_name_by_cte[_reference_cte]
+ assert _ == cte_name
+ else:
+ cte_level = len(self.stack) if cte.nesting else 1
cte_level_name = (cte_level, cte_name)
- if cte_level_name in self.ctes_by_name:
- existing_cte = self.ctes_by_name[cte_level_name]
+ if cte_level_name in self.ctes_by_level_name:
+ existing_cte = self.ctes_by_level_name[cte_level_name]
embedded_in_current_named_cte = visiting_cte is existing_cte
# we've generated a same-named CTE that we are enclosed in,
# or this is the same CTE. just return the name.
- if cte in existing_cte._restates or cte is existing_cte:
+ if cte is existing_cte._restates or cte is existing_cte:
is_new_cte = False
- elif existing_cte in cte._restates:
+ elif existing_cte is cte._restates:
# we've generated a same-named CTE that is
# enclosed in us - we take precedence, so
# discard the text for the "inner".
del self.ctes[existing_cte]
- del self.level_by_ctes[existing_cte]
+
+ existing_cte_reference_cte = existing_cte._get_reference_cte()
+
+ # TODO: determine if these assertions are correct. they
+ # pass for current test cases
+ # assert existing_cte_reference_cte is _reference_cte
+ # assert existing_cte_reference_cte is existing_cte
+
+ del self.level_name_by_cte[existing_cte_reference_cte]
else:
raise exc.CompileError(
"Multiple, unrelated CTEs found with "
"the same name: %r" % cte_name
)
- if asfrom or is_new_cte:
- if cte._cte_alias is not None:
- pre_alias_cte = cte._cte_alias
- cte_pre_alias_name = cte._cte_alias.name
- if isinstance(cte_pre_alias_name, elements._truncated_label):
- cte_pre_alias_name = self._truncated_identifier(
- "alias", cte_pre_alias_name
- )
- else:
- pre_alias_cte = cte
- cte_pre_alias_name = None
+ if not asfrom and not is_new_cte:
+ return None
+
+ if cte._cte_alias is not None:
+ pre_alias_cte = cte._cte_alias
+ cte_pre_alias_name = cte._cte_alias.name
+ if isinstance(cte_pre_alias_name, elements._truncated_label):
+ cte_pre_alias_name = self._truncated_identifier(
+ "alias", cte_pre_alias_name
+ )
+ else:
+ pre_alias_cte = cte
+ cte_pre_alias_name = None
if is_new_cte:
- self.ctes_by_name[cte_level_name] = cte
+ self.ctes_by_level_name[cte_level_name] = cte
+ self.level_name_by_cte[_reference_cte] = cte_level_name
if (
"autocommit" in cte.element._execution_options
)
self.ctes[cte] = text
- self.level_by_ctes[cte] = cte_level
if asfrom:
if from_linter:
if nesting_level and nesting_level > 1:
ctes = util.OrderedDict()
for cte in list(self.ctes.keys()):
- cte_level = self.level_by_ctes[cte]
+ cte_level, cte_name = self.level_name_by_cte[
+ cte._get_reference_cte()
+ ]
is_rendered_level = cte_level == nesting_level or (
include_following_stack and cte_level == nesting_level + 1
)
ctes[cte] = self.ctes[cte]
- del self.ctes[cte]
- del self.level_by_ctes[cte]
-
- cte_name = cte.name
- if isinstance(cte_name, elements._truncated_label):
- cte_name = self._truncated_identifier("alias", cte_name)
-
- del self.ctes_by_name[(cte_level, cte_name)]
else:
ctes = self.ctes
cte_text = self.get_cte_preamble(ctes_recursive) + " "
cte_text += ", \n".join([txt for txt in ctes.values()])
cte_text += "\n "
+
+ if nesting_level and nesting_level > 1:
+ for cte in list(ctes.keys()):
+ cte_level, cte_name = self.level_name_by_cte[
+ cte._get_reference_cte()
+ ]
+ del self.ctes[cte]
+ del self.ctes_by_level_name[(cte_level, cte_name)]
+ del self.level_name_by_cte[cte._get_reference_cte()]
+
return cte_text
def get_cte_preamble(self, recursive):
AliasedReturnsRows._traverse_internals
+ [
("_cte_alias", InternalTraversal.dp_clauseelement),
- ("_restates", InternalTraversal.dp_clauseelement_list),
+ ("_restates", InternalTraversal.dp_clauseelement),
("recursive", InternalTraversal.dp_boolean),
+ ("nesting", InternalTraversal.dp_boolean),
]
+ HasPrefixes._has_prefixes_traverse_internals
+ HasSuffixes._has_suffixes_traverse_internals
recursive=False,
nesting=False,
_cte_alias=None,
- _restates=(),
+ _restates=None,
_prefixes=None,
_suffixes=None,
):
self.recursive = recursive
self.nesting = nesting
self._cte_alias = _cte_alias
+ # Keep recursivity reference with union/union_all
self._restates = _restates
if _prefixes:
self._prefixes = _prefixes
name=self.name,
recursive=self.recursive,
nesting=self.nesting,
- _restates=self._restates + (self,),
+ _restates=self,
_prefixes=self._prefixes,
_suffixes=self._suffixes,
)
name=self.name,
recursive=self.recursive,
nesting=self.nesting,
- _restates=self._restates + (self,),
+ _restates=self,
_prefixes=self._prefixes,
_suffixes=self._suffixes,
)
+ def _get_reference_cte(self):
+ """
+ A recursive CTE is updated to attach the recursive part.
+ Updated CTEs should still refer to the original CTE.
+ This function returns this reference identifier.
+ """
+ return self._restates if self._restates is not None else self
+
class HasCTE(roles.HasCTERole):
"""Mixin that declares a class to include CTE support.
),
lambda: (
select(table_a.c.a).cte(),
+ select(table_a.c.a).cte(nesting=True),
select(table_a.c.a).cte(recursive=True),
select(table_a.c.a).cte(name="some_cte", recursive=True),
select(table_a.c.a).cte(name="some_cte"),
)
return stmt
- return [one(), one_diff(), two(), three()]
+ def four():
+ stmt = select(table_a.c.a).cte(recursive=True)
+ stmt = stmt.union(select(stmt.c.a + 1).where(stmt.c.a < 10))
+ return stmt
+
+ def five():
+ stmt = select(table_a.c.a).cte(recursive=True, nesting=True)
+ stmt = stmt.union(select(stmt.c.a + 1).where(stmt.c.a < 10))
+ return stmt
+
+ return [one(), one_diff(), two(), three(), four(), five()]
fixtures.append(_complex_fixtures)
cte = s1.cte(name="cte", recursive=True)
bar = select(cte).cte("bar").alias("cs1")
+
cte = cte.union_all(select(cte.c.x + 1).where(cte.c.x < 10)).alias(
"cs2"
)
"SELECT cte.outer_cte FROM cte",
)
+ def test_select_with_aliased_nesting_cte_in_cte(self):
+ nesting_cte = (
+ select(literal(1).label("inner_cte"))
+ .cte("nesting", nesting=True)
+ .alias("aliased_nested")
+ )
+ stmt = select(
+ select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte")
+ )
+
+ self.assert_compile(
+ stmt,
+ "WITH cte AS (WITH nesting AS (SELECT :param_1 AS inner_cte) "
+ "SELECT aliased_nested.inner_cte AS outer_cte "
+ "FROM nesting AS aliased_nested) "
+ "SELECT cte.outer_cte FROM cte",
+ )
+
def test_nesting_cte_in_cte_with_same_name(self):
nesting_cte = select(literal(1).label("inner_cte")).cte(
"some_cte", nesting=True
nesting_cte = select(literal(1).label("inner_cte")).cte(
"nesting", nesting=True
)
- stmt = select(
- select(nesting_cte.c.inner_cte.label("outer_cte")).cte(
- "cte", recursive=True
- )
+
+ rec_cte = select(nesting_cte.c.inner_cte.label("outer_cte")).cte(
+ "rec_cte", recursive=True
)
+ rec_part = select(rec_cte.c.outer_cte).where(
+ rec_cte.c.outer_cte == literal(1)
+ )
+ rec_cte = rec_cte.union(rec_part)
+
+ stmt = select(rec_cte)
self.assert_compile(
stmt,
- "WITH RECURSIVE cte(outer_cte) AS (WITH nesting AS "
+ "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS "
"(SELECT :param_1 AS inner_cte) "
- "SELECT nesting.inner_cte AS outer_cte FROM nesting) "
- "SELECT cte.outer_cte FROM cte",
+ "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
+ "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
+ "WHERE rec_cte.outer_cte = :param_2) "
+ "SELECT rec_cte.outer_cte FROM rec_cte",
)
def test_recursive_nesting_cte_in_cte(self):
- nesting_cte = select(literal(1).label("inner_cte")).cte(
- "nesting", nesting=True, recursive=True
+ rec_root = select(literal(1).label("inner_cte")).cte(
+ "nesting", recursive=True, nesting=True
+ )
+ rec_part = select(rec_root.c.inner_cte).where(
+ rec_root.c.inner_cte == literal(1)
)
+ nesting_cte = rec_root.union(rec_part)
+
stmt = select(
select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte")
)
self.assert_compile(
stmt,
"WITH cte AS (WITH RECURSIVE nesting(inner_cte) AS "
- "(SELECT :param_1 AS inner_cte) "
+ "(SELECT :param_1 AS inner_cte UNION "
+ "SELECT nesting.inner_cte AS inner_cte FROM nesting "
+ "WHERE nesting.inner_cte = :param_2) "
"SELECT nesting.inner_cte AS outer_cte FROM nesting) "
"SELECT cte.outer_cte FROM cte",
)
+ def test_anon_recursive_nesting_cte_in_cte(self):
+ rec_root = (
+ select(literal(1).label("inner_cte"))
+ .cte("nesting", recursive=True, nesting=True)
+ .alias()
+ )
+ rec_part = select(rec_root.c.inner_cte).where(
+ rec_root.c.inner_cte == literal(1)
+ )
+ nesting_cte = rec_root.union(rec_part)
+
+ stmt = select(
+ select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte")
+ )
+
+ self.assert_compile(
+ stmt,
+ "WITH cte AS (WITH RECURSIVE anon_1(inner_cte) AS "
+ "(SELECT :param_1 AS inner_cte UNION "
+ "SELECT anon_1.inner_cte AS inner_cte FROM anon_1 "
+ "WHERE anon_1.inner_cte = :param_2) "
+ "SELECT anon_1.inner_cte AS outer_cte FROM anon_1) "
+ "SELECT cte.outer_cte FROM cte",
+ )
+
+ def test_fully_aliased_recursive_nesting_cte_in_cte(self):
+ rec_root = (
+ select(literal(1).label("inner_cte"))
+ .cte("nesting", recursive=True, nesting=True)
+ .alias("aliased_nesting")
+ )
+ rec_part = select(rec_root.c.inner_cte).where(
+ rec_root.c.inner_cte == literal(1)
+ )
+ nesting_cte = rec_root.union(rec_part)
+
+ stmt = select(
+ select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte")
+ )
+
+ self.assert_compile(
+ stmt,
+ "WITH cte AS (WITH RECURSIVE aliased_nesting(inner_cte) AS "
+ "(SELECT :param_1 AS inner_cte UNION "
+ "SELECT aliased_nesting.inner_cte AS inner_cte "
+ "FROM aliased_nesting "
+ "WHERE aliased_nesting.inner_cte = :param_2) "
+ "SELECT aliased_nesting.inner_cte AS outer_cte "
+ "FROM aliased_nesting) "
+ "SELECT cte.outer_cte FROM cte",
+ )
+
+ def test_aliased_recursive_nesting_cte_in_cte(self):
+ rec_root = select(literal(1).label("inner_cte")).cte(
+ "nesting", recursive=True, nesting=True
+ )
+ rec_part = select(rec_root.c.inner_cte).where(
+ rec_root.c.inner_cte == literal(1)
+ )
+ nesting_cte = rec_root.union(rec_part).alias("aliased_nesting")
+
+ stmt = select(
+ select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte")
+ )
+
+ self.assert_compile(
+ stmt,
+ "WITH cte AS (WITH RECURSIVE nesting(inner_cte) AS "
+ "(SELECT :param_1 AS inner_cte UNION "
+ "SELECT nesting.inner_cte AS inner_cte FROM nesting "
+ "WHERE nesting.inner_cte = :param_2) "
+ "SELECT aliased_nesting.inner_cte AS outer_cte "
+ "FROM nesting AS aliased_nesting) "
+ "SELECT cte.outer_cte FROM cte",
+ )
+
def test_same_nested_cte_is_not_generated_twice(self):
# Same = name and query
nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte(
nesting_cte = select(literal(1).label("inner_cte")).cte(
"nesting", nesting=True, recursive=True
)
- stmt = select(
- select(nesting_cte.c.inner_cte.label("outer_cte")).cte(
- "cte", recursive=True
- )
+ nesting_rec_part = select(nesting_cte.c.inner_cte).where(
+ nesting_cte.c.inner_cte == literal(1)
+ )
+ nesting_cte = nesting_cte.union(nesting_rec_part)
+
+ rec_cte = select(nesting_cte.c.inner_cte.label("outer_cte")).cte(
+ "rec_cte", recursive=True
)
+ rec_part = select(rec_cte.c.outer_cte).where(
+ rec_cte.c.outer_cte == literal(1)
+ )
+ rec_cte = rec_cte.union(rec_part)
+
+ stmt = select(rec_cte)
self.assert_compile(
stmt,
- "WITH RECURSIVE cte(outer_cte) AS "
- "(WITH RECURSIVE nesting(inner_cte) "
- "AS (SELECT :param_1 AS inner_cte) "
- "SELECT nesting.inner_cte AS outer_cte FROM nesting) "
- "SELECT cte.outer_cte FROM cte",
+ "WITH RECURSIVE rec_cte(outer_cte) AS ("
+ "WITH RECURSIVE nesting(inner_cte) AS "
+ "(SELECT :param_1 AS inner_cte UNION "
+ "SELECT nesting.inner_cte AS inner_cte FROM nesting "
+ "WHERE nesting.inner_cte = :param_2) "
+ "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
+ "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
+ "WHERE rec_cte.outer_cte = :param_3) "
+ "SELECT rec_cte.outer_cte FROM rec_cte",
)
def test_select_from_insert_cte_with_nesting(self):