From bef67e58121704a9836e1e5ec2d361cd2086036c Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 23 Feb 2022 12:50:36 -0500 Subject: [PATCH] support add_cte() for TextualSelect Fixed issue where the :meth:`.HasCTE.add_cte` method as called upon a :class:`.TextualSelect` instance was not being accommodated by the SQL compiler. The fix additionally adds more "SELECT"-like compiler behavior to :class:`.TextualSelect` including that DML CTEs such as UPDATE and INSERT may be accommodated. Fixes: #7760 Change-Id: Id97062d882e9b2a81b8e31c2bfaa9cfc5f77d5c1 --- doc/build/changelog/unreleased_14/7760.rst | 9 +++ lib/sqlalchemy/orm/context.py | 64 +++++++++++++++++++++- lib/sqlalchemy/sql/compiler.py | 18 +++++- test/sql/test_cte.py | 45 +++++++++++++++ 4 files changed, 133 insertions(+), 3 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/7760.rst diff --git a/doc/build/changelog/unreleased_14/7760.rst b/doc/build/changelog/unreleased_14/7760.rst new file mode 100644 index 0000000000..2f0d403dd8 --- /dev/null +++ b/doc/build/changelog/unreleased_14/7760.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, sql + :tickets: 7760 + + Fixed issue where the :meth:`.HasCTE.add_cte` method as called upon a + :class:`.TextualSelect` instance was not being accommodated by the SQL + compiler. The fix additionally adds more "SELECT"-like compiler behavior to + :class:`.TextualSelect` including that DML CTEs such as UPDATE and INSERT + may be accommodated. diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index f51abde0c3..63ed10d501 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -27,6 +27,7 @@ from .. import future from .. import inspect from .. import sql from .. import util +from ..sql import ClauseElement from ..sql import coercions from ..sql import expression from ..sql import roles @@ -486,8 +487,8 @@ class ORMFromStatementCompileState(ORMCompileState): entity.setup_compile_state(self) # we did the setup just to get primary columns. - self.statement = expression.TextualSelect( - self.statement, self.primary_columns, positional=False + self.statement = _AdHocColumnsStatement( + self.statement, self.primary_columns ) else: # allow TextualSelect with implicit columns as well @@ -514,6 +515,65 @@ class ORMFromStatementCompileState(ORMCompileState): return None +class _AdHocColumnsStatement(ClauseElement): + """internal object created to somewhat act like a SELECT when we + are selecting columns from a DML RETURNING. + + + """ + + __visit_name__ = None + + def __init__(self, text, columns): + self.element = text + self.column_args = [ + coercions.expect(roles.ColumnsClauseRole, c) for c in columns + ] + + def _generate_cache_key(self): + raise NotImplementedError() + + def _gen_cache_key(self, anon_map, bindparams): + raise NotImplementedError() + + def _compiler_dispatch( + self, compiler, compound_index=None, asfrom=False, **kw + ): + """provide a fixed _compiler_dispatch method.""" + + toplevel = not compiler.stack + entry = ( + compiler._default_stack_entry if toplevel else compiler.stack[-1] + ) + + populate_result_map = ( + toplevel + # these two might not be needed + or ( + compound_index == 0 + and entry.get("need_result_map_for_compound", False) + ) + or entry.get("need_result_map_for_nested", False) + ) + + if populate_result_map: + compiler._ordered_columns = ( + compiler._textual_ordered_columns + ) = False + + # enable looser result column matching. this is shown to be + # needed by test_query.py::TextTest + compiler._loose_column_name_matching = True + + for c in self.column_args: + compiler.process( + c, + within_columns_clause=True, + add_to_result_map=compiler._add_to_result_map, + ) + return compiler.process(self.element, **kw) + + @sql.base.CompileState.plugin_for("orm", "select") class ORMSelectCompileState(ORMCompileState, SelectState): _already_joined_edges = () diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 4a169f719d..b140f92975 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1596,6 +1596,17 @@ class SQLCompiler(Compiled): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] + new_entry = { + "correlate_froms": set(), + "asfrom_froms": set(), + "selectable": taf, + } + self.stack.append(new_entry) + + if taf._independent_ctes: + for cte in taf._independent_ctes: + cte._compiler_dispatch(self, **kw) + populate_result_map = ( toplevel or ( @@ -1623,7 +1634,12 @@ class SQLCompiler(Compiled): add_to_result_map=self._add_to_result_map, ) - return self.process(taf.element, **kw) + text = self.process(taf.element, **kw) + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text + + return text def visit_null(self, expr, **kw): return "NULL" diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 64479b9692..b056925048 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1601,6 +1601,51 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): }, ) + def test_textual_select_uses_independent_cte_one(self): + """test #7760""" + products = table("products", column("id"), column("price")) + + upd_cte = ( + products.update().values(price=10).where(products.c.price > 50) + ).cte() + + stmt = ( + text( + "SELECT products.id, products.price " + "FROM products WHERE products.price < :price_2" + ) + .columns(products.c.id, products.c.price) + .bindparams(price_2=45) + .add_cte(upd_cte) + ) + + self.assert_compile( + stmt, + "WITH anon_1 AS (UPDATE products SET price=:param_1 " + "WHERE products.price > :price_1) " + "SELECT products.id, products.price " + "FROM products WHERE products.price < :price_2", + checkparams={"param_1": 10, "price_1": 50, "price_2": 45}, + ) + + def test_textual_select_uses_independent_cte_two(self): + + foo = table("foo", column("id")) + bar = table("bar", column("id"), column("attr"), column("foo_id")) + s1 = select(foo.c.id) + s2 = text( + "SELECT bar.id, bar.attr FROM bar " + "WHERE bar.foo_id IN (SELECT id FROM baz)" + ).columns(bar.c.id, bar.c.attr) + s3 = s2.add_cte(s1.cte(name="baz")) + + self.assert_compile( + s3, + "WITH baz AS (SELECT foo.id AS id FROM foo) " + "SELECT bar.id, bar.attr FROM bar WHERE bar.foo_id IN " + "(SELECT id FROM baz)", + ) + def test_insert_uses_independent_cte(self): products = table("products", column("id"), column("price")) -- 2.47.2