]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Dispatch independent ctes on compound select + test
authorEric Masseran <eric.masseran@gmail.com>
Wed, 4 Aug 2021 14:41:42 +0000 (16:41 +0200)
committerEric Masseran <eric.masseran@gmail.com>
Wed, 4 Aug 2021 14:41:42 +0000 (16:41 +0200)
lib/sqlalchemy/sql/compiler.py
test/sql/test_cte.py

index a81507acb94415f7f67f9b3656aece679e0ac875..1e090ef103fea7035d10d9c430bade1c7a975d6b 100644 (file)
@@ -1787,6 +1787,8 @@ class SQLCompiler(Compiled):
         if toplevel and not self.compile_state:
             self.compile_state = compile_state
 
+        compound_stmt = compile_state.statement
+
         entry = self._default_stack_entry if toplevel else self.stack[-1]
         need_result_map = toplevel or (
             not compound_index
@@ -1807,6 +1809,10 @@ class SQLCompiler(Compiled):
             }
         )
 
+        if compound_stmt._independent_ctes:
+            for cte in compound_stmt._independent_ctes:
+                cte._compiler_dispatch(self, **kwargs)
+
         keyword = self.compound_keywords.get(cs.keyword)
 
         text = (" " + keyword + " ").join(
index 77905cd896dbfd040c7f44fe7991e2983174c20a..cc663d53fbf300810294c53c4fd2f7b3fe56cb5d 100644 (file)
@@ -1533,6 +1533,37 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             checkparams={"param_1": 10, "price_1": 50, "price_2": 45},
         )
 
+    def test_compound_select_uses_independent_cte(self):
+        products = table("products", column("id"), column("price"))
+
+        upd_cte = (
+            products.update().values(price=10).where(products.c.price > 50)
+        ).cte()
+
+        stmt = (
+            products.select()
+            .where(products.c.price < 45)
+            .union(products.select().where(products.c.price > 90))
+            .add_cte(upd_cte)
+        )
+
+        self.assert_compile(
+            stmt,
+            "WITH anon_1 AS (UPDATE products SET price=:param_1 "
+            "WHERE products.price > :price_1) "
+            "SELECT products.id, products.price "
+            "FROM products WHERE products.price < :price_2 "
+            "UNION "
+            "SELECT products.id, products.price "
+            "FROM products WHERE products.price > :price_3",
+            checkparams={
+                "param_1": 10,
+                "price_1": 50,
+                "price_2": 45,
+                "price_3": 90,
+            },
+        )
+
     def test_insert_uses_independent_cte(self):
         products = table("products", column("id"), column("price"))