From 6d02de94b0c13d7e8f1182042a4c19581847542d Mon Sep 17 00:00:00 2001 From: Eric Masseran Date: Wed, 4 Aug 2021 16:41:42 +0200 Subject: [PATCH] Dispatch independent ctes on compound select + test --- lib/sqlalchemy/sql/compiler.py | 6 ++++++ test/sql/test_cte.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index a81507acb9..1e090ef103 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1787,6 +1787,8 @@ class SQLCompiler(Compiled): if toplevel and not self.compile_state: self.compile_state = compile_state + compound_stmt = compile_state.statement + entry = self._default_stack_entry if toplevel else self.stack[-1] need_result_map = toplevel or ( not compound_index @@ -1807,6 +1809,10 @@ class SQLCompiler(Compiled): } ) + if compound_stmt._independent_ctes: + for cte in compound_stmt._independent_ctes: + cte._compiler_dispatch(self, **kwargs) + keyword = self.compound_keywords.get(cs.keyword) text = (" " + keyword + " ").join( diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 77905cd896..cc663d53fb 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1533,6 +1533,37 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): checkparams={"param_1": 10, "price_1": 50, "price_2": 45}, ) + def test_compound_select_uses_independent_cte(self): + products = table("products", column("id"), column("price")) + + upd_cte = ( + products.update().values(price=10).where(products.c.price > 50) + ).cte() + + stmt = ( + products.select() + .where(products.c.price < 45) + .union(products.select().where(products.c.price > 90)) + .add_cte(upd_cte) + ) + + self.assert_compile( + stmt, + "WITH anon_1 AS (UPDATE products SET price=:param_1 " + "WHERE products.price > :price_1) " + "SELECT products.id, products.price " + "FROM products WHERE products.price < :price_2 " + "UNION " + "SELECT products.id, products.price " + "FROM products WHERE products.price > :price_3", + checkparams={ + "param_1": 10, + "price_1": 50, + "price_2": 45, + "price_3": 90, + }, + ) + def test_insert_uses_independent_cte(self): products = table("products", column("id"), column("price")) -- 2.47.3