--- /dev/null
+.. change::
+ :tags: bug, sql
+ :tickets: 7760
+
+ Fixed issue where the :meth:`.HasCTE.add_cte` method as called upon a
+ :class:`.TextualSelect` instance was not being accommodated by the SQL
+ compiler. The fix additionally adds more "SELECT"-like compiler behavior to
+ :class:`.TextualSelect` including that DML CTEs such as UPDATE and INSERT
+ may be accommodated.
from .. import inspect
from .. import sql
from .. import util
+from ..sql import ClauseElement
from ..sql import coercions
from ..sql import expression
from ..sql import roles
entity.setup_compile_state(self)
# we did the setup just to get primary columns.
- self.statement = expression.TextualSelect(
- self.statement, self.primary_columns, positional=False
+ self.statement = _AdHocColumnsStatement(
+ self.statement, self.primary_columns
)
else:
# allow TextualSelect with implicit columns as well
return None
+class _AdHocColumnsStatement(ClauseElement):
+ """internal object created to somewhat act like a SELECT when we
+ are selecting columns from a DML RETURNING.
+
+
+ """
+
+ __visit_name__ = None
+
+ def __init__(self, text, columns):
+ self.element = text
+ self.column_args = [
+ coercions.expect(roles.ColumnsClauseRole, c) for c in columns
+ ]
+
+ def _generate_cache_key(self):
+ raise NotImplementedError()
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ raise NotImplementedError()
+
+ def _compiler_dispatch(
+ self, compiler, compound_index=None, asfrom=False, **kw
+ ):
+ """provide a fixed _compiler_dispatch method."""
+
+ toplevel = not compiler.stack
+ entry = (
+ compiler._default_stack_entry if toplevel else compiler.stack[-1]
+ )
+
+ populate_result_map = (
+ toplevel
+ # these two might not be needed
+ or (
+ compound_index == 0
+ and entry.get("need_result_map_for_compound", False)
+ )
+ or entry.get("need_result_map_for_nested", False)
+ )
+
+ if populate_result_map:
+ compiler._ordered_columns = (
+ compiler._textual_ordered_columns
+ ) = False
+
+ # enable looser result column matching. this is shown to be
+ # needed by test_query.py::TextTest
+ compiler._loose_column_name_matching = True
+
+ for c in self.column_args:
+ compiler.process(
+ c,
+ within_columns_clause=True,
+ add_to_result_map=compiler._add_to_result_map,
+ )
+ return compiler.process(self.element, **kw)
+
+
@sql.base.CompileState.plugin_for("orm", "select")
class ORMSelectCompileState(ORMCompileState, SelectState):
_already_joined_edges = ()
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
+ new_entry = {
+ "correlate_froms": set(),
+ "asfrom_froms": set(),
+ "selectable": taf,
+ }
+ self.stack.append(new_entry)
+
+ if taf._independent_ctes:
+ for cte in taf._independent_ctes:
+ cte._compiler_dispatch(self, **kw)
+
populate_result_map = (
toplevel
or (
add_to_result_map=self._add_to_result_map,
)
- return self.process(taf.element, **kw)
+ text = self.process(taf.element, **kw)
+ if self.ctes:
+ nesting_level = len(self.stack) if not toplevel else None
+ text = self._render_cte_clause(nesting_level=nesting_level) + text
+
+ return text
def visit_null(self, expr, **kw):
return "NULL"
},
)
+ def test_textual_select_uses_independent_cte_one(self):
+ """test #7760"""
+ products = table("products", column("id"), column("price"))
+
+ upd_cte = (
+ products.update().values(price=10).where(products.c.price > 50)
+ ).cte()
+
+ stmt = (
+ text(
+ "SELECT products.id, products.price "
+ "FROM products WHERE products.price < :price_2"
+ )
+ .columns(products.c.id, products.c.price)
+ .bindparams(price_2=45)
+ .add_cte(upd_cte)
+ )
+
+ self.assert_compile(
+ stmt,
+ "WITH anon_1 AS (UPDATE products SET price=:param_1 "
+ "WHERE products.price > :price_1) "
+ "SELECT products.id, products.price "
+ "FROM products WHERE products.price < :price_2",
+ checkparams={"param_1": 10, "price_1": 50, "price_2": 45},
+ )
+
+ def test_textual_select_uses_independent_cte_two(self):
+
+ foo = table("foo", column("id"))
+ bar = table("bar", column("id"), column("attr"), column("foo_id"))
+ s1 = select(foo.c.id)
+ s2 = text(
+ "SELECT bar.id, bar.attr FROM bar "
+ "WHERE bar.foo_id IN (SELECT id FROM baz)"
+ ).columns(bar.c.id, bar.c.attr)
+ s3 = s2.add_cte(s1.cte(name="baz"))
+
+ self.assert_compile(
+ s3,
+ "WITH baz AS (SELECT foo.id AS id FROM foo) "
+ "SELECT bar.id, bar.attr FROM bar WHERE bar.foo_id IN "
+ "(SELECT id FROM baz)",
+ )
+
def test_insert_uses_independent_cte(self):
products = table("products", column("id"), column("price"))