]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support insert
authorEric Masseran <eric.masseran@gmail.com>
Sat, 10 Jul 2021 22:39:53 +0000 (00:39 +0200)
committerEric Masseran <eric.masseran@gmail.com>
Sat, 10 Jul 2021 22:39:53 +0000 (00:39 +0200)
lib/sqlalchemy/sql/compiler.py
test/sql/test_cte.py

index 66a8db1eab1c5920e1298a6eb6470509be7728bd..b55e8849ce65e1c65d8eed25ccd18eb82a6af51c 100644 (file)
@@ -2999,6 +2999,7 @@ class SQLCompiler(Compiled):
         self,
         select_stmt,
         asfrom=False,
+        insert_into=False,
         fromhints=None,
         compound_index=None,
         select_wraps_for=None,
@@ -3160,7 +3161,7 @@ class SQLCompiler(Compiled):
                 text += " " + self.get_statement_hint_text(per_dialect)
 
         # In compound query, CTEs are shared at the compound level
-        if self.ctes and compound_index is None:
+        if self.ctes and compound_index is None and not insert_into:
             nesting_level = len(self.stack) if not toplevel else None
             text = self._render_cte_clause(nesting_level=nesting_level) + text
 
@@ -3613,10 +3614,16 @@ class SQLCompiler(Compiled):
             returning_clause = None
 
         if insert_stmt.select is not None:
-            select_text = self.process(self._insert_from_select, **kw)
+            select_text = self.process(
+                self._insert_from_select, insert_into=True, **kw
+            )
 
-            if self.ctes and toplevel and self.dialect.cte_follows_insert:
-                text += " %s%s" % (self._render_cte_clause(), select_text)
+            if self.ctes and self.dialect.cte_follows_insert:
+                nesting_level = (len(self.stack) + 1) if not toplevel else None
+                text += " %s%s" % (
+                    self._render_cte_clause(nesting_level=nesting_level),
+                    select_text,
+                )
             else:
                 text += " %s" % select_text
         elif not crud_params and supports_default_values:
@@ -3654,8 +3661,9 @@ class SQLCompiler(Compiled):
         if returning_clause and not self.returning_precedes_values:
             text += " " + returning_clause
 
-        if self.ctes and toplevel and not self.dialect.cte_follows_insert:
-            text = self._render_cte_clause() + text
+        if self.ctes and not self.dialect.cte_follows_insert:
+            nesting_level = (len(self.stack) + 1) if not toplevel else None
+            text = self._render_cte_clause(nesting_level=nesting_level) + text
 
         self.stack.pop(-1)
 
index 234139ec69d663c2c432781539a3dcd5b318768f..bbe0df50dfab2a4acc1c8758968e4bf0b56f67b3 100644 (file)
@@ -1585,3 +1585,41 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             "dialect's statement compiler.",
             functools.partial(stmt.compile, dialect=dialect.dialect()),
         )
+
+    def test_nesting_cte_for_insert_in_the_cte(self):
+        products = table("products", column("id"), column("price"))
+
+        generator_cte = select(
+            [literal(1).label("id"), literal(27.0).label("price")]
+        ).cte("generator", nesting=True)
+
+        cte = (
+            products.insert()
+            .from_select(
+                [products.c.id, products.c.price],
+                select([generator_cte]),
+            )
+            .returning(*products.c)
+            .cte("pd")
+        )
+
+        stmt = select(cte)
+
+        assert "autocommit" not in stmt._execution_options
+
+        compiled = stmt.compile(dialect=self.__dialect__)
+
+        eq_(compiled.execution_options["autocommit"], True)
+
+        self.assert_compile(
+            stmt,
+            "WITH pd AS "
+            "(WITH generator AS "
+            "(SELECT %(param_1)s AS id, %(param_2)s AS price) "
+            "INSERT INTO products (id, price) "
+            "SELECT generator.id AS id, generator.price AS price FROM generator "
+            "RETURNING products.id, products.price) "
+            "SELECT pd.id, pd.price "
+            "FROM pd",
+        )
+        eq_(compiled.isinsert, False)