]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
render WITH clause after INSERT for INSERT..SELECT on Oracle, MySQL
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 15 Jun 2018 02:17:00 +0000 (22:17 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 18 Jun 2018 13:12:19 +0000 (09:12 -0400)
Fixed INSERT FROM SELECT with CTEs for the Oracle and MySQL dialects, where
the CTE was being placed above the entire statement as is typical with
other databases, however Oracle and MariaDB 10.2 wants the CTE underneath
the "INSERT" segment. Note that the Oracle and MySQL dialects don't yet
work when a CTE is applied to a subquery inside of an UPDATE or DELETE
statement, as the CTE is still applied to the top rather than inside the
subquery.

Also adds test suite support CTEs against backends.

Change-Id: I8ac337104d5c546dd4f0cd305632ffb56ac8bf90
Fixes: #4275
Fixes: #4230
12 files changed:
doc/build/changelog/unreleased_12/4275.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/testing/suite/__init__.py
lib/sqlalchemy/testing/suite/test_cte.py [new file with mode: 0644]
lib/sqlalchemy/testing/suite/test_select.py
test/requirements.py
test/sql/test_defaults.py
test/sql/test_insert.py

diff --git a/doc/build/changelog/unreleased_12/4275.rst b/doc/build/changelog/unreleased_12/4275.rst
new file mode 100644 (file)
index 0000000..8d18be5
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, oracle, mysql
+    :tickets: 4275
+    :versions: 1.3.0b1
+
+    Fixed INSERT FROM SELECT with CTEs for the Oracle and MySQL dialects, where
+    the CTE was being placed above the entire statement as is typical with
+    other databases, however Oracle and MariaDB 10.2 wants the CTE underneath
+    the "INSERT" segment. Note that the Oracle and MySQL dialects don't yet
+    work when a CTE is applied to a subquery inside of an UPDATE or DELETE
+    statement, as the CTE is still applied to the top rather than inside the
+    subquery.
+
index c8a3d3322538a16e8a247de72726a131c7a70c87..62753e1a5c90decf0f268df4db96efa402e84ee3 100644 (file)
@@ -1684,6 +1684,8 @@ class MySQLDialect(default.DefaultDialect):
     default_paramstyle = 'format'
     colspecs = colspecs
 
+    cte_follows_insert = True
+
     statement_compiler = MySQLCompiler
     ddl_compiler = MySQLDDLCompiler
     type_compiler = MySQLTypeCompiler
index 39acbf28d8507630612ac25ffa6f99da28ad2ac7..356c2a2bf15942db87520280c8a7005eacee118b 100644 (file)
@@ -1030,6 +1030,7 @@ class OracleDialect(default.DefaultDialect):
     max_identifier_length = 30
 
     supports_simple_order_by_label = False
+    cte_follows_insert = True
 
     supports_sequences = True
     sequences_optional = False
index 4d5f338bf23e7f3cabf0640c408051f9bb8c5e29..54fb25c16bd570ea934bc93565f29452fd28b106 100644 (file)
@@ -60,6 +60,7 @@ class DefaultDialect(interfaces.Dialect):
     implicit_returning = False
 
     supports_right_nested_joins = True
+    cte_follows_insert = False
 
     supports_native_enum = False
     supports_native_boolean = False
index a442c65fd6d0fffe6eb878c167cdc55b7f502282..0b98dc51c616f27255d9258e689e969d5aa7807c 100644 (file)
@@ -2105,7 +2105,12 @@ class SQLCompiler(Compiled):
             returning_clause = None
 
         if insert_stmt.select is not None:
-            text += " %s" % self.process(self._insert_from_select, **kw)
+            select_text = self.process(self._insert_from_select, **kw)
+
+            if self.ctes and toplevel and self.dialect.cte_follows_insert:
+                text += " %s%s" % (self._render_cte_clause(), select_text)
+            else:
+                text += " %s" % select_text
         elif not crud_params and supports_default_values:
             text += " DEFAULT VALUES"
         elif insert_stmt._has_multi_parameters:
@@ -2130,7 +2135,7 @@ class SQLCompiler(Compiled):
         if returning_clause and not self.returning_precedes_values:
             text += " " + returning_clause
 
-        if self.ctes and toplevel:
+        if self.ctes and toplevel and not self.dialect.cte_follows_insert:
             text = self._render_cte_clause() + text
 
         self.stack.pop(-1)
index b509c94d6100e85f45f9cbd7fa025d6538112de8..19d80e02862c35ab8922c0aacc82f49df1e04ca6 100644 (file)
@@ -179,10 +179,19 @@ class SuiteRequirements(Requirements):
 
         return exclusions.closed()
 
+    @property
+    def ctes_with_update_delete(self):
+        """target database supports CTES that ride on top of a normal UPDATE
+        or DELETE statement which refers to the CTE in a correlated subquery.
+
+        """
+
+        return exclusions.closed()
+
     @property
     def ctes_on_dml(self):
         """target database supports CTES which consist of INSERT, UPDATE
-        or DELETE"""
+        or DELETE *within* the CTE, e.g. WITH x AS (UPDATE....)"""
 
         return exclusions.closed()
 
index 9eeffd4cb0e6d92b0e01dfc2b8e33188729f7d64..748d9722d6d7e2d3c08a2383ba022ab0dadc488c 100644 (file)
@@ -1,4 +1,5 @@
 
+from sqlalchemy.testing.suite.test_cte import *
 from sqlalchemy.testing.suite.test_dialect import *
 from sqlalchemy.testing.suite.test_ddl import *
 from sqlalchemy.testing.suite.test_insert import *
diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py
new file mode 100644 (file)
index 0000000..cc72278
--- /dev/null
@@ -0,0 +1,193 @@
+from .. import fixtures, config
+from ..assertions import eq_
+
+from sqlalchemy import Integer, String, select
+from sqlalchemy import ForeignKey
+from sqlalchemy import testing
+
+from ..schema import Table, Column
+
+
+class CTETest(fixtures.TablesTest):
+    __backend__ = True
+    __requires__ = 'ctes',
+
+    run_inserts = 'each'
+    run_deletes = 'each'
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table("some_table", metadata,
+              Column('id', Integer, primary_key=True),
+              Column('data', String(50)),
+              Column("parent_id", ForeignKey("some_table.id")))
+
+        Table("some_other_table", metadata,
+              Column('id', Integer, primary_key=True),
+              Column('data', String(50)),
+              Column("parent_id", Integer))
+
+    @classmethod
+    def insert_data(cls):
+        config.db.execute(
+            cls.tables.some_table.insert(),
+            [
+                {"id": 1, "data": "d1", "parent_id": None},
+                {"id": 2, "data": "d2", "parent_id": 1},
+                {"id": 3, "data": "d3", "parent_id": 1},
+                {"id": 4, "data": "d4", "parent_id": 3},
+                {"id": 5, "data": "d5", "parent_id": 3}
+            ]
+        )
+
+    def test_select_nonrecursive_round_trip(self):
+        some_table = self.tables.some_table
+
+        with config.db.connect() as conn:
+            cte = select([some_table]).where(
+                some_table.c.data.in_(["d2", "d3", "d4"])).cte("some_cte")
+            result = conn.execute(
+                select([cte.c.data]).where(cte.c.data.in_(["d4", "d5"]))
+            )
+            eq_(result.fetchall(), [("d4", )])
+
+    def test_select_recursive_round_trip(self):
+        some_table = self.tables.some_table
+
+        with config.db.connect() as conn:
+            cte = select([some_table]).where(
+                some_table.c.data.in_(["d2", "d3", "d4"])).cte(
+                "some_cte", recursive=True)
+
+            cte_alias = cte.alias("c1")
+            st1 = some_table.alias()
+            # note that SQL Server requires this to be UNION ALL,
+            # can't be UNION
+            cte = cte.union_all(
+                select([st1]).where(st1.c.id == cte_alias.c.parent_id)
+            )
+            result = conn.execute(
+                select([cte.c.data]).where(
+                    cte.c.data != "d2").order_by(cte.c.data.desc())
+            )
+            eq_(
+                result.fetchall(),
+                [('d4',), ('d3',), ('d3',), ('d1',), ('d1',), ('d1',)]
+            )
+
+    def test_insert_from_select_round_trip(self):
+        some_table = self.tables.some_table
+        some_other_table = self.tables.some_other_table
+
+        with config.db.connect() as conn:
+            cte = select([some_table]).where(
+                some_table.c.data.in_(["d2", "d3", "d4"])
+            ).cte("some_cte")
+            conn.execute(
+                some_other_table.insert().from_select(
+                    ["id", "data", "parent_id"],
+                    select([cte])
+                )
+            )
+            eq_(
+                conn.execute(
+                    select([some_other_table]).order_by(some_other_table.c.id)
+                ).fetchall(),
+                [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)]
+            )
+
+    @testing.requires.ctes_with_update_delete
+    @testing.requires.update_from
+    def test_update_from_round_trip(self):
+        some_table = self.tables.some_table
+        some_other_table = self.tables.some_other_table
+
+        with config.db.connect() as conn:
+            conn.execute(
+                some_other_table.insert().from_select(
+                    ['id', 'data', 'parent_id'],
+                    select([some_table])
+                )
+            )
+
+            cte = select([some_table]).where(
+                some_table.c.data.in_(["d2", "d3", "d4"])
+            ).cte("some_cte")
+            conn.execute(
+                some_other_table.update().values(parent_id=5).where(
+                    some_other_table.c.data == cte.c.data
+                )
+            )
+            eq_(
+                conn.execute(
+                    select([some_other_table]).order_by(some_other_table.c.id)
+                ).fetchall(),
+                [
+                    (1, "d1", None), (2, "d2", 5),
+                    (3, "d3", 5), (4, "d4", 5), (5, "d5", 3)
+                ]
+            )
+
+    @testing.requires.ctes_with_update_delete
+    @testing.requires.delete_from
+    def test_delete_from_round_trip(self):
+        some_table = self.tables.some_table
+        some_other_table = self.tables.some_other_table
+
+        with config.db.connect() as conn:
+            conn.execute(
+                some_other_table.insert().from_select(
+                    ['id', 'data', 'parent_id'],
+                    select([some_table])
+                )
+            )
+
+            cte = select([some_table]).where(
+                some_table.c.data.in_(["d2", "d3", "d4"])
+            ).cte("some_cte")
+            conn.execute(
+                some_other_table.delete().where(
+                    some_other_table.c.data == cte.c.data
+                )
+            )
+            eq_(
+                conn.execute(
+                    select([some_other_table]).order_by(some_other_table.c.id)
+                ).fetchall(),
+                [
+                    (1, "d1", None), (5, "d5", 3)
+                ]
+            )
+
+    @testing.requires.ctes_with_update_delete
+    def test_delete_scalar_subq_round_trip(self):
+
+        some_table = self.tables.some_table
+        some_other_table = self.tables.some_other_table
+
+        with config.db.connect() as conn:
+            conn.execute(
+                some_other_table.insert().from_select(
+                    ['id', 'data', 'parent_id'],
+                    select([some_table])
+                )
+            )
+
+            cte = select([some_table]).where(
+                some_table.c.data.in_(["d2", "d3", "d4"])
+            ).cte("some_cte")
+            conn.execute(
+                some_other_table.delete().where(
+                    some_other_table.c.data ==
+                    select([cte.c.data]).where(
+                        cte.c.id == some_other_table.c.id)
+                )
+            )
+            eq_(
+                conn.execute(
+                    select([some_other_table]).order_by(some_other_table.c.id)
+                ).fetchall(),
+                [
+                    (1, "d1", None), (5, "d5", 3)
+                ]
+            )
index d9755c8f972c6b38629dc6234d5aebd56d61c7e7..05b9162de572d9e9c843d3b01d9f972c67ab457d 100644 (file)
@@ -511,3 +511,5 @@ class LikeFunctionsTest(fixtures.TablesTest):
         col = self.tables.some_table.c.data
         self._test(col.contains("b%cd", autoescape=True, escape="#"), {3})
         self._test(col.contains("b#cd", autoescape=True, escape="#"), {7})
+
+
index 4a53b76ecb6904c1a308b34d6196e912540ff0ec..c1e30daf6d815920b118d70e354d2ce7972cf80c 100644 (file)
@@ -348,7 +348,7 @@ class DefaultRequirements(SuiteRequirements):
     def delete_from(self):
         """Target must support DELETE FROM..FROM or DELETE..USING syntax"""
         return only_on(['postgresql', 'mssql', 'mysql', 'sybase'],
-                       "Backend does not support UPDATE..FROM")
+                       "Backend does not support DELETE..FROM")
 
     @property
     def update_where_target_in_subquery(self):
@@ -466,14 +466,34 @@ class DefaultRequirements(SuiteRequirements):
     def ctes(self):
         """Target database supports CTEs"""
 
-        return only_if(
-            ['postgresql', 'mssql']
-        )
+        return only_on([
+            lambda config: against(config, "mysql") and (
+                config.db.dialect._is_mariadb and
+                config.db.dialect._mariadb_normalized_version_info >=
+                (10, 2)
+            ),
+            "postgresql",
+            "mssql",
+            "oracle"
+        ])
+
+    @property
+    def ctes_with_update_delete(self):
+        """target database supports CTES that ride on top of a normal UPDATE
+        or DELETE statement which refers to the CTE in a correlated subquery.
+
+        """
+        return only_on([
+            "postgresql",
+            "mssql",
+            # "oracle" - oracle can do this but SQLAlchemy doesn't support
+            # their syntax yet
+        ])
 
     @property
     def ctes_on_dml(self):
         """target database supports CTES which consist of INSERT, UPDATE
-        or DELETE"""
+        or DELETE *within* the CTE, e.g. WITH x AS (UPDATE....)"""
 
         return only_if(
             ['postgresql']
index fc42d420f84236cdab5ff6b9f8f2237e09b5d1d1..c53670a05f004ec0efffaf59fb7fbe4c9bc3618f 100644 (file)
@@ -568,7 +568,7 @@ class DefaultTest(fixtures.TestBase):
 
 
 class CTEDefaultTest(fixtures.TablesTest):
-    __requires__ = ('ctes',)
+    __requires__ = ('ctes', 'returning', 'ctes_on_dml')
     __backend__ = True
 
     @classmethod
index 6d41a4dca565a46f90f9dba24acf9b50ba158462..6ea5b4f37c90213f15fc2ef0413aae632a139498 100644 (file)
@@ -278,6 +278,31 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             checkparams={"name_1": "bar"}
         )
 
+    def test_insert_from_select_cte_follows_insert_one(self):
+        dialect = default.DefaultDialect()
+        dialect.cte_follows_insert = True
+
+        table1 = self.tables.mytable
+
+        cte = select([table1.c.name]).where(table1.c.name == 'bar').cte()
+
+        sel = select([table1.c.myid, table1.c.name]).where(
+            table1.c.name == cte.c.name)
+
+        ins = self.tables.myothertable.insert().\
+            from_select(("otherid", "othername"), sel)
+        self.assert_compile(
+            ins,
+            "INSERT INTO myothertable (otherid, othername) "
+            "WITH anon_1 AS "
+            "(SELECT mytable.name AS name FROM mytable "
+            "WHERE mytable.name = :name_1) "
+            "SELECT mytable.myid, mytable.name FROM mytable, anon_1 "
+            "WHERE mytable.name = anon_1.name",
+            checkparams={"name_1": "bar"},
+            dialect=dialect
+        )
+
     def test_insert_from_select_cte_two(self):
         table1 = self.tables.mytable
 
@@ -293,6 +318,24 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             "SELECT c.myid, c.name, c.description FROM c"
         )
 
+    def test_insert_from_select_cte_follows_insert_two(self):
+        dialect = default.DefaultDialect()
+        dialect.cte_follows_insert = True
+        table1 = self.tables.mytable
+
+        cte = table1.select().cte("c")
+        stmt = cte.select()
+        ins = table1.insert().from_select(table1.c, stmt)
+
+        self.assert_compile(
+            ins,
+            "INSERT INTO mytable (myid, name, description) "
+            "WITH c AS (SELECT mytable.myid AS myid, mytable.name AS name, "
+            "mytable.description AS description FROM mytable) "
+            "SELECT c.myid, c.name, c.description FROM c",
+            dialect=dialect
+        )
+
     def test_insert_from_select_select_alt_ordering(self):
         table1 = self.tables.mytable
         sel = select([table1.c.name, table1.c.myid]).where(