"""
# collect CTEs to tack on top of a SELECT
# Dict[cte_id, text_query]
+ # To remember the query to print
self.ctes = util.OrderedDict()
# Detect same CTE references
- # Dict[cte_id, cte_instance]
+ # Dict[(level, name), cte_instance]
self.ctes_by_name = {}
+ # Dict[cte_id, cte_instance]
+ # Useful
+ self.ctes_by_id = {}
+ # Dict[cte_id, cte_instance]
+ # To retrieve ctes_by_name key
+ self.names_by_id = {}
# Dict[cte_id, level]
+ # Remember level for nesting usage
self.level_by_ctes = {}
self.ctes_recursive = False
if self.positional:
is_new_cte = True
embedded_in_current_named_cte = False
- # if cte.unique_id in self.level_by_ctes:
- # cte_level = self.level_by_ctes[cte.unique_id]
+ if cte.unique_id in self.level_by_ctes:
+ cte_level = self.level_by_ctes[cte.unique_id]
- cte_level_name = cte.unique_id
+ cte_level_name = (cte_level, cte_name)
if cte_level_name in self.ctes_by_name:
existing_cte = self.ctes_by_name[cte_level_name]
embedded_in_current_named_cte = visiting_cte is existing_cte
# we've generated a same-named CTE that is
# enclosed in us - we take precedence, so
# discard the text for the "inner".
- del self.ctes[existing_cte.unique_id]
+ del self.ctes[existing_cte]
del self.level_by_ctes[existing_cte.unique_id]
else:
raise exc.CompileError(
if is_new_cte:
self.ctes_by_name[cte_level_name] = cte
- # TODO:
- # self.level_by_ctes[cte.unique_id] = cte_level
+ self.names_by_id[cte.unique_id] = cte_name
+ self.level_by_ctes[cte.unique_id] = cte_level
if (
"autocommit" in cte.element._execution_options
)
if self.positional:
- kwargs["positional_names"] = self.cte_positional[
- cte.unique_id
- ] = []
+ kwargs["positional_names"] = self.cte_positional[cte] = []
assert kwargs.get("subquery", False) is False
cte, cte._suffixes, **kwargs
)
- self.ctes[cte.unique_id] = text
+ self.ctes[cte] = text
self.level_by_ctes[cte.unique_id] = cte_level
if asfrom:
if nesting_level and nesting_level > 1:
ctes = util.OrderedDict()
- for cte_id in list(self.ctes.keys()):
- cte = self.ctes_by_name[cte_id]
- cte_level = self.level_by_ctes[cte_id]
+ for cte in list(self.ctes.keys()):
+ cte_level = self.level_by_ctes[cte.unique_id]
+ cte_name = self.names_by_id[cte.unique_id]
is_rendered_level = cte_level == nesting_level or (
include_following_stack and cte_level == nesting_level + 1
)
if not (cte.nesting and is_rendered_level):
continue
- ctes[cte_id] = self.ctes[cte_id]
+ ctes[cte] = self.ctes[cte]
else:
ctes = self.ctes
if not ctes:
return ""
- ctes_recursive = any(
- [self.ctes_by_name[cte_id].recursive for cte_id in ctes]
- )
+ ctes_recursive = any([cte.recursive for cte in ctes])
if self.positional:
self.positiontup = (
- sum([self.cte_positional[cte_id] for cte_id in ctes], [])
+ sum([self.cte_positional[cte] for cte in ctes], [])
+ self.positiontup
)
cte_text = self.get_cte_preamble(ctes_recursive) + " "
cte_text += ", \n".join([txt for txt in ctes.values()])
cte_text += "\n "
- for cte_id in list(ctes.keys()):
- del self.ctes[cte_id]
- del self.level_by_ctes[cte_id]
- del self.ctes_by_name[cte_id]
+ if nesting_level and nesting_level > 1:
+ for cte in list(ctes.keys()):
+ cte_level = self.level_by_ctes[cte.unique_id]
+ cte_name = self.names_by_id[cte.unique_id]
+ del self.ctes[cte]
+ del self.level_by_ctes[cte.unique_id]
+ del self.names_by_id[cte.unique_id]
+ del self.ctes_by_name[(cte_level, cte_name)]
return cte_text
" SELECT recursive_cte.the_value FROM recursive_cte",
)
- # def test_recursive_cte_w_union_aliased(self):
- # nesting_cte = select(literal(1).label("inner_cte")).cte(
- # "nesting", recursive=True, nesting=True
- # )
- # nesting_cte_a = nesting_cte.alias()
- # nesting_cte = nesting_cte.union(
- # select(nesting_cte_a.c.inner_cte).where(
- # nesting_cte_a.c.inner_cte == literal(1)
- # )
- # )
-
- # stmt = select(nesting_cte.c.inner_cte)
- # self.assert_compile(
- # stmt,
- # "WITH RECURSIVE nesting(inner_cte) AS "
- # "(SELECT :param_1 AS inner_cte UNION "
- # "SELECT anon_1.inner_cte AS inner_cte FROM nesting AS anon_1 "
- # "WHERE anon_1.inner_cte = :param_2) "
- # "SELECT nesting.inner_cte FROM nesting",
- # )
+ def test_recursive_cte_w_union_aliased(self):
+ nesting_cte = select(literal(1).label("inner_cte")).cte(
+ "nesting", recursive=True, nesting=True
+ )
+ nesting_cte_a = nesting_cte.alias()
+ nesting_cte = nesting_cte.union(
+ select(nesting_cte_a.c.inner_cte).where(
+ nesting_cte_a.c.inner_cte == literal(1)
+ )
+ )
+
+ stmt = select(nesting_cte.c.inner_cte)
+ self.assert_compile(
+ stmt,
+ "WITH RECURSIVE nesting(inner_cte) AS "
+ "(SELECT :param_1 AS inner_cte UNION "
+ "SELECT anon_1.inner_cte AS inner_cte FROM nesting AS anon_1 "
+ "WHERE anon_1.inner_cte = :param_2) "
+ "SELECT nesting.inner_cte FROM nesting",
+ )
def test_recursive_cte_w_union(self):
nesting_cte = select(literal(1).label("inner_cte")).cte(