]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
omit mysql8 dupe key alias for INSERT..FROM SELECT
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 13 Aug 2024 13:13:51 +0000 (09:13 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 13 Aug 2024 13:13:51 +0000 (09:13 -0400)
Fixed issue in MySQL dialect where using INSERT..FROM SELECT in combination
with ON DUPLICATE KEY UPDATE would erroneously render on MySQL 8 and above
the "AS new" clause, leading to syntax failures.  This clause is required
on MySQL 8 to follow the VALUES clause if use of the "new" alias is
present, however is not permitted to follow a FROM SELECT clause.

Fixes: #11731
Change-Id: I254a3db4e9dccd9a76b11fdfe6e38a064ba0b5cf

doc/build/changelog/unreleased_20/11731.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
test/dialect/mysql/test_compiler.py
test/dialect/mysql/test_on_duplicate.py

diff --git a/doc/build/changelog/unreleased_20/11731.rst b/doc/build/changelog/unreleased_20/11731.rst
new file mode 100644 (file)
index 0000000..34ab8b4
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, mysql
+    :tickets: 11731
+
+    Fixed issue in MySQL dialect where using INSERT..FROM SELECT in combination
+    with ON DUPLICATE KEY UPDATE would erroneously render on MySQL 8 and above
+    the "AS new" clause, leading to syntax failures.  This clause is required
+    on MySQL 8 to follow the VALUES clause if use of the "new" alias is
+    present, however is not permitted to follow a FROM SELECT clause.
+
index d5db02d2781bd371d63a5d51025cd97a287aee3b..aa99bf4d6849d24fdae11f5924810e02e748708f 100644 (file)
@@ -1349,7 +1349,7 @@ class MySQLCompiler(compiler.SQLCompiler):
 
         clauses = []
 
-        requires_mysql8_alias = (
+        requires_mysql8_alias = statement.select is None and (
             self.dialect._requires_alias_for_on_duplicate_key
         )
 
index 6712300aa40cd00bad4ea14b1ad2718dd1db2966..189390659add8df9c39b465ed5a9ae652895543e 100644 (file)
@@ -1127,6 +1127,31 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
 
         self.assert_compile(stmt, expected_sql, dialect=dialect)
 
+    @testing.variation("version", ["mysql8", "all_others"])
+    def test_from_select(self, version: Variation):
+        stmt = insert(self.table).from_select(
+            ["id", "bar"],
+            select(self.table.c.id, literal("bar2")),
+        )
+        stmt = stmt.on_duplicate_key_update(
+            bar=stmt.inserted.bar, baz=stmt.inserted.baz
+        )
+
+        expected_sql = (
+            "INSERT INTO foos (id, bar) SELECT foos.id, %s AS anon_1 "
+            "FROM foos "
+            "ON DUPLICATE KEY UPDATE bar = VALUES(bar), baz = VALUES(baz)"
+        )
+        if version.all_others:
+            dialect = None
+        elif version.mysql8:
+            dialect = mysql.dialect()
+            dialect._requires_alias_for_on_duplicate_key = True
+        else:
+            version.fail()
+
+        self.assert_compile(stmt, expected_sql, dialect=dialect)
+
     def test_from_literal(self):
         stmt = insert(self.table).values(
             [{"id": 1, "bar": "ab"}, {"id": 2, "bar": "b"}]
index 5a4e6ca8d5cc10cf6d9898404103c5cf709b99dc..35aebb470c30bf0976986a057f9fa5b3c2ac09b3 100644 (file)
@@ -3,6 +3,8 @@ from sqlalchemy import Column
 from sqlalchemy import exc
 from sqlalchemy import func
 from sqlalchemy import Integer
+from sqlalchemy import literal
+from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy.dialects.mysql import insert
@@ -63,6 +65,22 @@ class OnDuplicateTest(fixtures.TablesTest):
             [(1, "ab", "bz", False)],
         )
 
+    def test_on_duplicate_key_from_select(self, connection):
+        foos = self.tables.foos
+        conn = connection
+        conn.execute(insert(foos).values(dict(id=1, bar="b", baz="bz")))
+        stmt = insert(foos).from_select(
+            ["id", "bar", "baz"],
+            select(foos.c.id, literal("bar2"), literal("baz2")),
+        )
+        stmt = stmt.on_duplicate_key_update(bar=stmt.inserted.bar)
+
+        conn.execute(stmt)
+        eq_(
+            conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
+            [(1, "bar2", "bz", False)],
+        )
+
     def test_on_duplicate_key_update_singlerow(self, connection):
         foos = self.tables.foos
         conn = connection