From: Eric Masseran Date: Sat, 10 Jul 2021 22:39:53 +0000 (+0200) Subject: Support insert X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f55d8bfd83773755ca1605e851272ea03c7df6e5;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support insert --- diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 66a8db1eab..b55e8849ce 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2999,6 +2999,7 @@ class SQLCompiler(Compiled): self, select_stmt, asfrom=False, + insert_into=False, fromhints=None, compound_index=None, select_wraps_for=None, @@ -3160,7 +3161,7 @@ class SQLCompiler(Compiled): text += " " + self.get_statement_hint_text(per_dialect) # In compound query, CTEs are shared at the compound level - if self.ctes and compound_index is None: + if self.ctes and compound_index is None and not insert_into: nesting_level = len(self.stack) if not toplevel else None text = self._render_cte_clause(nesting_level=nesting_level) + text @@ -3613,10 +3614,16 @@ class SQLCompiler(Compiled): returning_clause = None if insert_stmt.select is not None: - select_text = self.process(self._insert_from_select, **kw) + select_text = self.process( + self._insert_from_select, insert_into=True, **kw + ) - if self.ctes and toplevel and self.dialect.cte_follows_insert: - text += " %s%s" % (self._render_cte_clause(), select_text) + if self.ctes and self.dialect.cte_follows_insert: + nesting_level = (len(self.stack) + 1) if not toplevel else None + text += " %s%s" % ( + self._render_cte_clause(nesting_level=nesting_level), + select_text, + ) else: text += " %s" % select_text elif not crud_params and supports_default_values: @@ -3654,8 +3661,9 @@ class SQLCompiler(Compiled): if returning_clause and not self.returning_precedes_values: text += " " + returning_clause - if self.ctes and toplevel and not self.dialect.cte_follows_insert: - text = self._render_cte_clause() + text + if self.ctes and not self.dialect.cte_follows_insert: + nesting_level = (len(self.stack) + 1) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text self.stack.pop(-1) diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 234139ec69..bbe0df50df 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1585,3 +1585,41 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "dialect's statement compiler.", functools.partial(stmt.compile, dialect=dialect.dialect()), ) + + def test_nesting_cte_for_insert_in_the_cte(self): + products = table("products", column("id"), column("price")) + + generator_cte = select( + [literal(1).label("id"), literal(27.0).label("price")] + ).cte("generator", nesting=True) + + cte = ( + products.insert() + .from_select( + [products.c.id, products.c.price], + select([generator_cte]), + ) + .returning(*products.c) + .cte("pd") + ) + + stmt = select(cte) + + assert "autocommit" not in stmt._execution_options + + compiled = stmt.compile(dialect=self.__dialect__) + + eq_(compiled.execution_options["autocommit"], True) + + self.assert_compile( + stmt, + "WITH pd AS " + "(WITH generator AS " + "(SELECT %(param_1)s AS id, %(param_2)s AS price) " + "INSERT INTO products (id, price) " + "SELECT generator.id AS id, generator.price AS price FROM generator " + "RETURNING products.id, products.price) " + "SELECT pd.id, pd.price " + "FROM pd", + ) + eq_(compiled.isinsert, False)