]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
use mysql 8 syntax for ON DUPLICATE KEY UPDATE
authorCaspar Wylie <casparwylie@Caspars-MacBook-Pro.local>
Fri, 3 Feb 2023 14:23:26 +0000 (09:23 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 6 Feb 2023 16:00:57 +0000 (11:00 -0500)
Added support for MySQL 8's new ``AS <name> ON DUPLICATE KEY`` syntax when
using :meth:`_mysql.Insert.on_duplicate_key_update`, which is required for
newer versions of MySQL 8 as the previous syntax using ``VALUES()`` now
emits a deprecation warning with those versions. Server version detection
is employed to determine if traditional MariaDB / MySQL < 8 ``VALUES()``
syntax should be used, vs. the newer MySQL 8 required syntax. Pull request
courtesy Caspar Wylie.

Fixes: #8626
Closes: #9210
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9210
Pull-request-sha: 1c8dfbf0b4c439d9ca2c194524c47eb7239ee3c5

Change-Id: I42c463837af06bc15b60c534159804193df07f02

doc/build/changelog/unreleased_20/8626.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
test/dialect/mysql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_20/8626.rst b/doc/build/changelog/unreleased_20/8626.rst
new file mode 100644 (file)
index 0000000..c12e803
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, mysql
+    :tickets: 8626
+
+    Added support for MySQL 8's new ``AS <name> ON DUPLICATE KEY`` syntax when
+    using :meth:`_mysql.Insert.on_duplicate_key_update`, which is required for
+    newer versions of MySQL 8 as the previous syntax using ``VALUES()`` now
+    emits a deprecation warning with those versions. Server version detection
+    is employed to determine if traditional MariaDB / MySQL < 8 ``VALUES()``
+    syntax should be used, vs. the newer MySQL 8 required syntax. Pull request
+    courtesy Caspar Wylie.
index 50e0ec07eaee43cfabf37d1b95dca4f86d6ac905..87fdabff5f2ad172bbba996862cf712d1db142fa 100644 (file)
@@ -1080,7 +1080,6 @@ SET_RE = re.compile(
     r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE
 )
 
-
 # old names
 MSTime = TIME
 MSSet = SET
@@ -1316,9 +1315,19 @@ class MySQLCompiler(compiler.SQLCompiler):
             cols = statement.table.c
 
         clauses = []
+
+        requires_mysql8_alias = (
+            self.dialect._requires_alias_for_on_duplicate_key
+        )
+
+        if requires_mysql8_alias:
+            if statement.table.name.lower() == "new":
+                _on_dup_alias_name = "new_1"
+            else:
+                _on_dup_alias_name = "new"
+
         # traverses through all table columns to preserve table column order
         for column in (col for col in cols if col.key in on_duplicate.update):
-
             val = on_duplicate.update[column.key]
 
             if coercions._is_literal(val):
@@ -1338,10 +1347,16 @@ class MySQLCompiler(compiler.SQLCompiler):
                         isinstance(obj, elements.ColumnClause)
                         and obj.table is on_duplicate.inserted_alias
                     ):
-                        obj = literal_column(
-                            "VALUES(" + self.preparer.quote(obj.name) + ")"
-                        )
-                        return obj
+                        if requires_mysql8_alias:
+                            column_literal_clause = (
+                                f"{_on_dup_alias_name}."
+                                f"{self.preparer.quote(obj.name)}"
+                            )
+                        else:
+                            column_literal_clause = (
+                                f"VALUES({self.preparer.quote(obj.name)})"
+                            )
+                        return literal_column(column_literal_clause)
                     else:
                         # element is not replaced
                         return None
@@ -1363,7 +1378,13 @@ class MySQLCompiler(compiler.SQLCompiler):
                 )
             )
 
-        return "ON DUPLICATE KEY UPDATE " + ", ".join(clauses)
+        if requires_mysql8_alias:
+            return (
+                f"AS {_on_dup_alias_name} "
+                f"ON DUPLICATE KEY UPDATE {', '.join(clauses)}"
+            )
+        else:
+            return f"ON DUPLICATE KEY UPDATE {', '.join(clauses)}"
 
     def visit_concat_op_expression_clauselist(
         self, clauselist, operator, **kw
@@ -2391,6 +2412,9 @@ class MySQLDialect(default.DefaultDialect):
     supports_for_update_of = False  # default for MySQL ...
     # ... may be updated to True for MySQL 8+ in initialize()
 
+    _requires_alias_for_on_duplicate_key = False  # Only available ...
+    # ... in MySQL 8+
+
     # MySQL doesn't support "DEFAULT VALUES" but *does* support
     # "VALUES (DEFAULT)"
     supports_default_values = False
@@ -2783,6 +2807,10 @@ class MySQLDialect(default.DefaultDialect):
             self.is_mariadb and self.server_version_info >= (10, 5)
         )
 
+        self._requires_alias_for_on_duplicate_key = (
+            self._is_mysql and self.server_version_info >= (8, 0, 20)
+        )
+
         self._warn_for_known_db_issues()
 
     def _warn_for_known_db_issues(self):
index 414f73ad7650f44d15e91f8939072de78c2afa3c..52d4529aecd3f2ace37e71a31f981118d8dd41b2 100644 (file)
@@ -63,6 +63,7 @@ from sqlalchemy.testing import eq_ignore_whitespace
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import mock
+from sqlalchemy.testing import Variation
 
 
 class ReservedWordFixture(AssertsCompiledSQL):
@@ -1100,18 +1101,34 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
                 bar=stmt.inserted.bar, baz=stmt.inserted.baz
             )
 
-    def test_from_values(self):
+    @testing.variation("version", ["mysql8", "all_others"])
+    def test_from_values(self, version: Variation):
         stmt = insert(self.table).values(
             [{"id": 1, "bar": "ab"}, {"id": 2, "bar": "b"}]
         )
         stmt = stmt.on_duplicate_key_update(
             bar=stmt.inserted.bar, baz=stmt.inserted.baz
         )
-        expected_sql = (
-            "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) "
-            "ON DUPLICATE KEY UPDATE bar = VALUES(bar), baz = VALUES(baz)"
-        )
-        self.assert_compile(stmt, expected_sql)
+
+        if version.all_others:
+            expected_sql = (
+                "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) "
+                "ON DUPLICATE KEY UPDATE bar = VALUES(bar), baz = VALUES(baz)"
+            )
+            dialect = None
+        elif version.mysql8:
+            expected_sql = (
+                "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) "
+                "AS new ON DUPLICATE KEY UPDATE "
+                "bar = new.bar, "
+                "baz = new.baz"
+            )
+            dialect = mysql.dialect()
+            dialect._requires_alias_for_on_duplicate_key = True
+        else:
+            version.fail()
+
+        self.assert_compile(stmt, expected_sql, dialect=dialect)
 
     def test_from_literal(self):
         stmt = insert(self.table).values(
@@ -1135,7 +1152,8 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
         )
         self.assert_compile(stmt, expected_sql)
 
-    def test_update_sql_expr(self):
+    @testing.variation("version", ["mysql8", "all_others"])
+    def test_update_sql_expr(self, version: Variation):
         stmt = insert(self.table).values(
             [{"id": 1, "bar": "ab"}, {"id": 2, "bar": "b"}]
         )
@@ -1143,11 +1161,60 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
             bar=func.coalesce(stmt.inserted.bar),
             baz=stmt.inserted.baz + "some literal" + stmt.inserted.bar,
         )
+
+        if version.all_others:
+            expected_sql = (
+                "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) ON "
+                "DUPLICATE KEY UPDATE bar = coalesce(VALUES(bar)), "
+                "baz = (concat(VALUES(baz), %s, VALUES(bar)))"
+            )
+            dialect = None
+        elif version.mysql8:
+
+            expected_sql = (
+                "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) "
+                "AS new ON DUPLICATE KEY UPDATE bar = "
+                "coalesce(new.bar), "
+                "baz = (concat(new.baz, %s, "
+                "new.bar))"
+            )
+            dialect = mysql.dialect()
+            dialect._requires_alias_for_on_duplicate_key = True
+        else:
+            version.fail()
+
+        self.assert_compile(
+            stmt,
+            expected_sql,
+            checkparams={
+                "id_m0": 1,
+                "bar_m0": "ab",
+                "id_m1": 2,
+                "bar_m1": "b",
+                "baz_1": "some literal",
+            },
+            dialect=dialect,
+        )
+
+    def test_mysql8_on_update_dont_dup_alias_name(self):
+        t = table("new", column("id"), column("bar"), column("baz"))
+        stmt = insert(t).values(
+            [{"id": 1, "bar": "ab"}, {"id": 2, "bar": "b"}]
+        )
+        stmt = stmt.on_duplicate_key_update(
+            bar=func.coalesce(stmt.inserted.bar),
+            baz=stmt.inserted.baz + "some literal" + stmt.inserted.bar,
+        )
+
         expected_sql = (
-            "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) ON "
-            "DUPLICATE KEY UPDATE bar = coalesce(VALUES(bar)), "
-            "baz = (concat(VALUES(baz), %s, VALUES(bar)))"
+            "INSERT INTO new (id, bar) VALUES (%s, %s), (%s, %s) "
+            "AS new_1 ON DUPLICATE KEY UPDATE bar = "
+            "coalesce(new_1.bar), "
+            "baz = (concat(new_1.baz, %s, "
+            "new_1.bar))"
         )
+        dialect = mysql.dialect()
+        dialect._requires_alias_for_on_duplicate_key = True
         self.assert_compile(
             stmt,
             expected_sql,
@@ -1158,6 +1225,7 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
                 "bar_m1": "b",
                 "baz_1": "some literal",
             },
+            dialect=dialect,
         )