From 3619edcb8aa3ceef2a44925b85315fc0e90c5982 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 14 Jun 2018 22:17:00 -0400 Subject: [PATCH] render WITH clause after INSERT for INSERT..SELECT on Oracle, MySQL Fixed INSERT FROM SELECT with CTEs for the Oracle and MySQL dialects, where the CTE was being placed above the entire statement as is typical with other databases, however Oracle and MariaDB 10.2 wants the CTE underneath the "INSERT" segment. Note that the Oracle and MySQL dialects don't yet work when a CTE is applied to a subquery inside of an UPDATE or DELETE statement, as the CTE is still applied to the top rather than inside the subquery. Also adds test suite support CTEs against backends. Change-Id: I8ac337104d5c546dd4f0cd305632ffb56ac8bf90 Fixes: #4275 Fixes: #4230 --- doc/build/changelog/unreleased_12/4275.rst | 13 ++ lib/sqlalchemy/dialects/mysql/base.py | 2 + lib/sqlalchemy/dialects/oracle/base.py | 1 + lib/sqlalchemy/engine/default.py | 1 + lib/sqlalchemy/sql/compiler.py | 9 +- lib/sqlalchemy/testing/requirements.py | 11 +- lib/sqlalchemy/testing/suite/__init__.py | 1 + lib/sqlalchemy/testing/suite/test_cte.py | 193 ++++++++++++++++++++ lib/sqlalchemy/testing/suite/test_select.py | 2 + test/requirements.py | 30 ++- test/sql/test_defaults.py | 2 +- test/sql/test_insert.py | 43 +++++ 12 files changed, 299 insertions(+), 9 deletions(-) create mode 100644 doc/build/changelog/unreleased_12/4275.rst create mode 100644 lib/sqlalchemy/testing/suite/test_cte.py diff --git a/doc/build/changelog/unreleased_12/4275.rst b/doc/build/changelog/unreleased_12/4275.rst new file mode 100644 index 0000000000..8d18be5049 --- /dev/null +++ b/doc/build/changelog/unreleased_12/4275.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: bug, oracle, mysql + :tickets: 4275 + :versions: 1.3.0b1 + + Fixed INSERT FROM SELECT with CTEs for the Oracle and MySQL dialects, where + the CTE was being placed above the entire statement as is typical with + other databases, however Oracle and MariaDB 10.2 wants the CTE underneath + the "INSERT" segment. Note that the Oracle and MySQL dialects don't yet + work when a CTE is applied to a subquery inside of an UPDATE or DELETE + statement, as the CTE is still applied to the top rather than inside the + subquery. + diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index c8a3d33225..62753e1a5c 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1684,6 +1684,8 @@ class MySQLDialect(default.DefaultDialect): default_paramstyle = 'format' colspecs = colspecs + cte_follows_insert = True + statement_compiler = MySQLCompiler ddl_compiler = MySQLDDLCompiler type_compiler = MySQLTypeCompiler diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 39acbf28d8..356c2a2bf1 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -1030,6 +1030,7 @@ class OracleDialect(default.DefaultDialect): max_identifier_length = 30 supports_simple_order_by_label = False + cte_follows_insert = True supports_sequences = True sequences_optional = False diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 4d5f338bf2..54fb25c16b 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -60,6 +60,7 @@ class DefaultDialect(interfaces.Dialect): implicit_returning = False supports_right_nested_joins = True + cte_follows_insert = False supports_native_enum = False supports_native_boolean = False diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index a442c65fd6..0b98dc51c6 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2105,7 +2105,12 @@ class SQLCompiler(Compiled): returning_clause = None if insert_stmt.select is not None: - text += " %s" % self.process(self._insert_from_select, **kw) + select_text = self.process(self._insert_from_select, **kw) + + if self.ctes and toplevel and self.dialect.cte_follows_insert: + text += " %s%s" % (self._render_cte_clause(), select_text) + else: + text += " %s" % select_text elif not crud_params and supports_default_values: text += " DEFAULT VALUES" elif insert_stmt._has_multi_parameters: @@ -2130,7 +2135,7 @@ class SQLCompiler(Compiled): if returning_clause and not self.returning_precedes_values: text += " " + returning_clause - if self.ctes and toplevel: + if self.ctes and toplevel and not self.dialect.cte_follows_insert: text = self._render_cte_clause() + text self.stack.pop(-1) diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index b509c94d61..19d80e0286 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -179,10 +179,19 @@ class SuiteRequirements(Requirements): return exclusions.closed() + @property + def ctes_with_update_delete(self): + """target database supports CTES that ride on top of a normal UPDATE + or DELETE statement which refers to the CTE in a correlated subquery. + + """ + + return exclusions.closed() + @property def ctes_on_dml(self): """target database supports CTES which consist of INSERT, UPDATE - or DELETE""" + or DELETE *within* the CTE, e.g. WITH x AS (UPDATE....)""" return exclusions.closed() diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py index 9eeffd4cb0..748d9722d6 100644 --- a/lib/sqlalchemy/testing/suite/__init__.py +++ b/lib/sqlalchemy/testing/suite/__init__.py @@ -1,4 +1,5 @@ +from sqlalchemy.testing.suite.test_cte import * from sqlalchemy.testing.suite.test_dialect import * from sqlalchemy.testing.suite.test_ddl import * from sqlalchemy.testing.suite.test_insert import * diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py new file mode 100644 index 0000000000..cc72278e6c --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_cte.py @@ -0,0 +1,193 @@ +from .. import fixtures, config +from ..assertions import eq_ + +from sqlalchemy import Integer, String, select +from sqlalchemy import ForeignKey +from sqlalchemy import testing + +from ..schema import Table, Column + + +class CTETest(fixtures.TablesTest): + __backend__ = True + __requires__ = 'ctes', + + run_inserts = 'each' + run_deletes = 'each' + + @classmethod + def define_tables(cls, metadata): + Table("some_table", metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column("parent_id", ForeignKey("some_table.id"))) + + Table("some_other_table", metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column("parent_id", Integer)) + + @classmethod + def insert_data(cls): + config.db.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "data": "d1", "parent_id": None}, + {"id": 2, "data": "d2", "parent_id": 1}, + {"id": 3, "data": "d3", "parent_id": 1}, + {"id": 4, "data": "d4", "parent_id": 3}, + {"id": 5, "data": "d5", "parent_id": 3} + ] + ) + + def test_select_nonrecursive_round_trip(self): + some_table = self.tables.some_table + + with config.db.connect() as conn: + cte = select([some_table]).where( + some_table.c.data.in_(["d2", "d3", "d4"])).cte("some_cte") + result = conn.execute( + select([cte.c.data]).where(cte.c.data.in_(["d4", "d5"])) + ) + eq_(result.fetchall(), [("d4", )]) + + def test_select_recursive_round_trip(self): + some_table = self.tables.some_table + + with config.db.connect() as conn: + cte = select([some_table]).where( + some_table.c.data.in_(["d2", "d3", "d4"])).cte( + "some_cte", recursive=True) + + cte_alias = cte.alias("c1") + st1 = some_table.alias() + # note that SQL Server requires this to be UNION ALL, + # can't be UNION + cte = cte.union_all( + select([st1]).where(st1.c.id == cte_alias.c.parent_id) + ) + result = conn.execute( + select([cte.c.data]).where( + cte.c.data != "d2").order_by(cte.c.data.desc()) + ) + eq_( + result.fetchall(), + [('d4',), ('d3',), ('d3',), ('d1',), ('d1',), ('d1',)] + ) + + def test_insert_from_select_round_trip(self): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + with config.db.connect() as conn: + cte = select([some_table]).where( + some_table.c.data.in_(["d2", "d3", "d4"]) + ).cte("some_cte") + conn.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], + select([cte]) + ) + ) + eq_( + conn.execute( + select([some_other_table]).order_by(some_other_table.c.id) + ).fetchall(), + [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)] + ) + + @testing.requires.ctes_with_update_delete + @testing.requires.update_from + def test_update_from_round_trip(self): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + with config.db.connect() as conn: + conn.execute( + some_other_table.insert().from_select( + ['id', 'data', 'parent_id'], + select([some_table]) + ) + ) + + cte = select([some_table]).where( + some_table.c.data.in_(["d2", "d3", "d4"]) + ).cte("some_cte") + conn.execute( + some_other_table.update().values(parent_id=5).where( + some_other_table.c.data == cte.c.data + ) + ) + eq_( + conn.execute( + select([some_other_table]).order_by(some_other_table.c.id) + ).fetchall(), + [ + (1, "d1", None), (2, "d2", 5), + (3, "d3", 5), (4, "d4", 5), (5, "d5", 3) + ] + ) + + @testing.requires.ctes_with_update_delete + @testing.requires.delete_from + def test_delete_from_round_trip(self): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + with config.db.connect() as conn: + conn.execute( + some_other_table.insert().from_select( + ['id', 'data', 'parent_id'], + select([some_table]) + ) + ) + + cte = select([some_table]).where( + some_table.c.data.in_(["d2", "d3", "d4"]) + ).cte("some_cte") + conn.execute( + some_other_table.delete().where( + some_other_table.c.data == cte.c.data + ) + ) + eq_( + conn.execute( + select([some_other_table]).order_by(some_other_table.c.id) + ).fetchall(), + [ + (1, "d1", None), (5, "d5", 3) + ] + ) + + @testing.requires.ctes_with_update_delete + def test_delete_scalar_subq_round_trip(self): + + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + with config.db.connect() as conn: + conn.execute( + some_other_table.insert().from_select( + ['id', 'data', 'parent_id'], + select([some_table]) + ) + ) + + cte = select([some_table]).where( + some_table.c.data.in_(["d2", "d3", "d4"]) + ).cte("some_cte") + conn.execute( + some_other_table.delete().where( + some_other_table.c.data == + select([cte.c.data]).where( + cte.c.id == some_other_table.c.id) + ) + ) + eq_( + conn.execute( + select([some_other_table]).order_by(some_other_table.c.id) + ).fetchall(), + [ + (1, "d1", None), (5, "d5", 3) + ] + ) diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index d9755c8f97..05b9162de5 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -511,3 +511,5 @@ class LikeFunctionsTest(fixtures.TablesTest): col = self.tables.some_table.c.data self._test(col.contains("b%cd", autoescape=True, escape="#"), {3}) self._test(col.contains("b#cd", autoescape=True, escape="#"), {7}) + + diff --git a/test/requirements.py b/test/requirements.py index 4a53b76ecb..c1e30daf6d 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -348,7 +348,7 @@ class DefaultRequirements(SuiteRequirements): def delete_from(self): """Target must support DELETE FROM..FROM or DELETE..USING syntax""" return only_on(['postgresql', 'mssql', 'mysql', 'sybase'], - "Backend does not support UPDATE..FROM") + "Backend does not support DELETE..FROM") @property def update_where_target_in_subquery(self): @@ -466,14 +466,34 @@ class DefaultRequirements(SuiteRequirements): def ctes(self): """Target database supports CTEs""" - return only_if( - ['postgresql', 'mssql'] - ) + return only_on([ + lambda config: against(config, "mysql") and ( + config.db.dialect._is_mariadb and + config.db.dialect._mariadb_normalized_version_info >= + (10, 2) + ), + "postgresql", + "mssql", + "oracle" + ]) + + @property + def ctes_with_update_delete(self): + """target database supports CTES that ride on top of a normal UPDATE + or DELETE statement which refers to the CTE in a correlated subquery. + + """ + return only_on([ + "postgresql", + "mssql", + # "oracle" - oracle can do this but SQLAlchemy doesn't support + # their syntax yet + ]) @property def ctes_on_dml(self): """target database supports CTES which consist of INSERT, UPDATE - or DELETE""" + or DELETE *within* the CTE, e.g. WITH x AS (UPDATE....)""" return only_if( ['postgresql'] diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index fc42d420f8..c53670a05f 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -568,7 +568,7 @@ class DefaultTest(fixtures.TestBase): class CTEDefaultTest(fixtures.TablesTest): - __requires__ = ('ctes',) + __requires__ = ('ctes', 'returning', 'ctes_on_dml') __backend__ = True @classmethod diff --git a/test/sql/test_insert.py b/test/sql/test_insert.py index 6d41a4dca5..6ea5b4f37c 100644 --- a/test/sql/test_insert.py +++ b/test/sql/test_insert.py @@ -278,6 +278,31 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): checkparams={"name_1": "bar"} ) + def test_insert_from_select_cte_follows_insert_one(self): + dialect = default.DefaultDialect() + dialect.cte_follows_insert = True + + table1 = self.tables.mytable + + cte = select([table1.c.name]).where(table1.c.name == 'bar').cte() + + sel = select([table1.c.myid, table1.c.name]).where( + table1.c.name == cte.c.name) + + ins = self.tables.myothertable.insert().\ + from_select(("otherid", "othername"), sel) + self.assert_compile( + ins, + "INSERT INTO myothertable (otherid, othername) " + "WITH anon_1 AS " + "(SELECT mytable.name AS name FROM mytable " + "WHERE mytable.name = :name_1) " + "SELECT mytable.myid, mytable.name FROM mytable, anon_1 " + "WHERE mytable.name = anon_1.name", + checkparams={"name_1": "bar"}, + dialect=dialect + ) + def test_insert_from_select_cte_two(self): table1 = self.tables.mytable @@ -293,6 +318,24 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): "SELECT c.myid, c.name, c.description FROM c" ) + def test_insert_from_select_cte_follows_insert_two(self): + dialect = default.DefaultDialect() + dialect.cte_follows_insert = True + table1 = self.tables.mytable + + cte = table1.select().cte("c") + stmt = cte.select() + ins = table1.insert().from_select(table1.c, stmt) + + self.assert_compile( + ins, + "INSERT INTO mytable (myid, name, description) " + "WITH c AS (SELECT mytable.myid AS myid, mytable.name AS name, " + "mytable.description AS description FROM mytable) " + "SELECT c.myid, c.name, c.description FROM c", + dialect=dialect + ) + def test_insert_from_select_select_alt_ordering(self): table1 = self.tables.mytable sel = select([table1.c.name, table1.c.myid]).where( -- 2.47.2