From eb1bb84fbc10c801c7269a3d38c9e0235327857e Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 8 May 2015 12:37:55 -0400 Subject: [PATCH] - Added official support for a CTE used by the SELECT present inside of :meth:`.Insert.from_select`. This behavior worked accidentally up until 0.9.9, when it no longer worked due to unrelated changes as part of :ticket:`3248`. Note that this is the rendering of the WITH clause after the INSERT, before the SELECT; the full functionality of CTEs rendered at the top level of INSERT, UPDATE, DELETE is a new feature targeted for a later release. fixes #3418 --- doc/build/changelog/changelog_09.rst | 14 +++++++++++ lib/sqlalchemy/sql/compiler.py | 16 ++++++++++++- test/sql/test_insert.py | 35 ++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 1 deletion(-) diff --git a/doc/build/changelog/changelog_09.rst b/doc/build/changelog/changelog_09.rst index 2506d21bdb..2d2964ba44 100644 --- a/doc/build/changelog/changelog_09.rst +++ b/doc/build/changelog/changelog_09.rst @@ -14,6 +14,20 @@ .. changelog:: :version: 0.9.10 + .. change:: + :tags: feature, sql + :tickets: 3418 + :versions: 1.0.5 + + Added official support for a CTE used by the SELECT present + inside of :meth:`.Insert.from_select`. This behavior worked + accidentally up until 0.9.9, when it no longer worked due to + unrelated changes as part of :ticket:`3248`. Note that this + is the rendering of the WITH clause after the INSERT, before the + SELECT; the full functionality of CTEs rendered at the top + level of INSERT, UPDATE, DELETE is a new feature targeted for a + later release. + .. change:: :tags: bug, ext :tickets: 3408 diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index c9c7fd2a15..e9c3d0efab 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1613,7 +1613,7 @@ class SQLCompiler(Compiled): if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) - if self.ctes and toplevel: + if self.ctes and self._is_toplevel_select(select): text = self._render_cte_clause() + text if select._suffixes: @@ -1627,6 +1627,20 @@ class SQLCompiler(Compiled): else: return text + def _is_toplevel_select(self, select): + """Return True if the stack is placed at the given select, and + is also the outermost SELECT, meaning there is either no stack + before this one, or the enclosing stack is a topmost INSERT. + + """ + return ( + self.stack[-1]['selectable'] is select and + ( + len(self.stack) == 1 or self.isinsert and len(self.stack) == 2 + and self.statement is self.stack[0]['selectable'] + ) + ) + def _setup_select_hints(self, select): byfrom = dict([ (from_, hinttext % { diff --git a/test/sql/test_insert.py b/test/sql/test_insert.py index 7170fcbcb3..3c533d75fd 100644 --- a/test/sql/test_insert.py +++ b/test/sql/test_insert.py @@ -176,6 +176,41 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): checkparams={"name_1": "foo"} ) + def test_insert_from_select_cte_one(self): + table1 = self.tables.mytable + + cte = select([table1.c.name]).where(table1.c.name == 'bar').cte() + + sel = select([table1.c.myid, table1.c.name]).where( + table1.c.name == cte.c.name) + + ins = self.tables.myothertable.insert().\ + from_select(("otherid", "othername"), sel) + self.assert_compile( + ins, + "INSERT INTO myothertable (otherid, othername) WITH anon_1 AS " + "(SELECT mytable.name AS name FROM mytable " + "WHERE mytable.name = :name_1) " + "SELECT mytable.myid, mytable.name FROM mytable, anon_1 " + "WHERE mytable.name = anon_1.name", + checkparams={"name_1": "bar"} + ) + + def test_insert_from_select_cte_two(self): + table1 = self.tables.mytable + + cte = table1.select().cte("c") + stmt = cte.select() + ins = table1.insert().from_select(table1.c, stmt) + + self.assert_compile( + ins, + "INSERT INTO mytable (myid, name, description) " + "WITH c AS (SELECT mytable.myid AS myid, mytable.name AS name, " + "mytable.description AS description FROM mytable) " + "SELECT c.myid, c.name, c.description FROM c" + ) + def test_insert_from_select_select_alt_ordering(self): table1 = self.tables.mytable sel = select([table1.c.name, table1.c.myid]).where( -- 2.47.3