From 2a9487fefcc915ae411a8edc48d5203619ed642b Mon Sep 17 00:00:00 2001 From: Caspar Wylie Date: Fri, 3 Feb 2023 09:23:26 -0500 Subject: [PATCH] use mysql 8 syntax for ON DUPLICATE KEY UPDATE Added support for MySQL 8's new ``AS ON DUPLICATE KEY`` syntax when using :meth:`_mysql.Insert.on_duplicate_key_update`, which is required for newer versions of MySQL 8 as the previous syntax using ``VALUES()`` now emits a deprecation warning with those versions. Server version detection is employed to determine if traditional MariaDB / MySQL < 8 ``VALUES()`` syntax should be used, vs. the newer MySQL 8 required syntax. Pull request courtesy Caspar Wylie. Fixes: #8626 Closes: #9210 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9210 Pull-request-sha: 1c8dfbf0b4c439d9ca2c194524c47eb7239ee3c5 Change-Id: I42c463837af06bc15b60c534159804193df07f02 --- doc/build/changelog/unreleased_20/8626.rst | 11 +++ lib/sqlalchemy/dialects/mysql/base.py | 42 +++++++++-- test/dialect/mysql/test_compiler.py | 88 +++++++++++++++++++--- 3 files changed, 124 insertions(+), 17 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/8626.rst diff --git a/doc/build/changelog/unreleased_20/8626.rst b/doc/build/changelog/unreleased_20/8626.rst new file mode 100644 index 0000000000..c12e803834 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8626.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, mysql + :tickets: 8626 + + Added support for MySQL 8's new ``AS ON DUPLICATE KEY`` syntax when + using :meth:`_mysql.Insert.on_duplicate_key_update`, which is required for + newer versions of MySQL 8 as the previous syntax using ``VALUES()`` now + emits a deprecation warning with those versions. Server version detection + is employed to determine if traditional MariaDB / MySQL < 8 ``VALUES()`` + syntax should be used, vs. the newer MySQL 8 required syntax. Pull request + courtesy Caspar Wylie. diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 50e0ec07ea..87fdabff5f 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1080,7 +1080,6 @@ SET_RE = re.compile( r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE ) - # old names MSTime = TIME MSSet = SET @@ -1316,9 +1315,19 @@ class MySQLCompiler(compiler.SQLCompiler): cols = statement.table.c clauses = [] + + requires_mysql8_alias = ( + self.dialect._requires_alias_for_on_duplicate_key + ) + + if requires_mysql8_alias: + if statement.table.name.lower() == "new": + _on_dup_alias_name = "new_1" + else: + _on_dup_alias_name = "new" + # traverses through all table columns to preserve table column order for column in (col for col in cols if col.key in on_duplicate.update): - val = on_duplicate.update[column.key] if coercions._is_literal(val): @@ -1338,10 +1347,16 @@ class MySQLCompiler(compiler.SQLCompiler): isinstance(obj, elements.ColumnClause) and obj.table is on_duplicate.inserted_alias ): - obj = literal_column( - "VALUES(" + self.preparer.quote(obj.name) + ")" - ) - return obj + if requires_mysql8_alias: + column_literal_clause = ( + f"{_on_dup_alias_name}." + f"{self.preparer.quote(obj.name)}" + ) + else: + column_literal_clause = ( + f"VALUES({self.preparer.quote(obj.name)})" + ) + return literal_column(column_literal_clause) else: # element is not replaced return None @@ -1363,7 +1378,13 @@ class MySQLCompiler(compiler.SQLCompiler): ) ) - return "ON DUPLICATE KEY UPDATE " + ", ".join(clauses) + if requires_mysql8_alias: + return ( + f"AS {_on_dup_alias_name} " + f"ON DUPLICATE KEY UPDATE {', '.join(clauses)}" + ) + else: + return f"ON DUPLICATE KEY UPDATE {', '.join(clauses)}" def visit_concat_op_expression_clauselist( self, clauselist, operator, **kw @@ -2391,6 +2412,9 @@ class MySQLDialect(default.DefaultDialect): supports_for_update_of = False # default for MySQL ... # ... may be updated to True for MySQL 8+ in initialize() + _requires_alias_for_on_duplicate_key = False # Only available ... + # ... in MySQL 8+ + # MySQL doesn't support "DEFAULT VALUES" but *does* support # "VALUES (DEFAULT)" supports_default_values = False @@ -2783,6 +2807,10 @@ class MySQLDialect(default.DefaultDialect): self.is_mariadb and self.server_version_info >= (10, 5) ) + self._requires_alias_for_on_duplicate_key = ( + self._is_mysql and self.server_version_info >= (8, 0, 20) + ) + self._warn_for_known_db_issues() def _warn_for_known_db_issues(self): diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 414f73ad76..52d4529aec 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -63,6 +63,7 @@ from sqlalchemy.testing import eq_ignore_whitespace from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock +from sqlalchemy.testing import Variation class ReservedWordFixture(AssertsCompiledSQL): @@ -1100,18 +1101,34 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL): bar=stmt.inserted.bar, baz=stmt.inserted.baz ) - def test_from_values(self): + @testing.variation("version", ["mysql8", "all_others"]) + def test_from_values(self, version: Variation): stmt = insert(self.table).values( [{"id": 1, "bar": "ab"}, {"id": 2, "bar": "b"}] ) stmt = stmt.on_duplicate_key_update( bar=stmt.inserted.bar, baz=stmt.inserted.baz ) - expected_sql = ( - "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) " - "ON DUPLICATE KEY UPDATE bar = VALUES(bar), baz = VALUES(baz)" - ) - self.assert_compile(stmt, expected_sql) + + if version.all_others: + expected_sql = ( + "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) " + "ON DUPLICATE KEY UPDATE bar = VALUES(bar), baz = VALUES(baz)" + ) + dialect = None + elif version.mysql8: + expected_sql = ( + "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) " + "AS new ON DUPLICATE KEY UPDATE " + "bar = new.bar, " + "baz = new.baz" + ) + dialect = mysql.dialect() + dialect._requires_alias_for_on_duplicate_key = True + else: + version.fail() + + self.assert_compile(stmt, expected_sql, dialect=dialect) def test_from_literal(self): stmt = insert(self.table).values( @@ -1135,7 +1152,8 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL): ) self.assert_compile(stmt, expected_sql) - def test_update_sql_expr(self): + @testing.variation("version", ["mysql8", "all_others"]) + def test_update_sql_expr(self, version: Variation): stmt = insert(self.table).values( [{"id": 1, "bar": "ab"}, {"id": 2, "bar": "b"}] ) @@ -1143,11 +1161,60 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL): bar=func.coalesce(stmt.inserted.bar), baz=stmt.inserted.baz + "some literal" + stmt.inserted.bar, ) + + if version.all_others: + expected_sql = ( + "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) ON " + "DUPLICATE KEY UPDATE bar = coalesce(VALUES(bar)), " + "baz = (concat(VALUES(baz), %s, VALUES(bar)))" + ) + dialect = None + elif version.mysql8: + + expected_sql = ( + "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) " + "AS new ON DUPLICATE KEY UPDATE bar = " + "coalesce(new.bar), " + "baz = (concat(new.baz, %s, " + "new.bar))" + ) + dialect = mysql.dialect() + dialect._requires_alias_for_on_duplicate_key = True + else: + version.fail() + + self.assert_compile( + stmt, + expected_sql, + checkparams={ + "id_m0": 1, + "bar_m0": "ab", + "id_m1": 2, + "bar_m1": "b", + "baz_1": "some literal", + }, + dialect=dialect, + ) + + def test_mysql8_on_update_dont_dup_alias_name(self): + t = table("new", column("id"), column("bar"), column("baz")) + stmt = insert(t).values( + [{"id": 1, "bar": "ab"}, {"id": 2, "bar": "b"}] + ) + stmt = stmt.on_duplicate_key_update( + bar=func.coalesce(stmt.inserted.bar), + baz=stmt.inserted.baz + "some literal" + stmt.inserted.bar, + ) + expected_sql = ( - "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) ON " - "DUPLICATE KEY UPDATE bar = coalesce(VALUES(bar)), " - "baz = (concat(VALUES(baz), %s, VALUES(bar)))" + "INSERT INTO new (id, bar) VALUES (%s, %s), (%s, %s) " + "AS new_1 ON DUPLICATE KEY UPDATE bar = " + "coalesce(new_1.bar), " + "baz = (concat(new_1.baz, %s, " + "new_1.bar))" ) + dialect = mysql.dialect() + dialect._requires_alias_for_on_duplicate_key = True self.assert_compile( stmt, expected_sql, @@ -1158,6 +1225,7 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL): "bar_m1": "b", "baz_1": "some literal", }, + dialect=dialect, ) -- 2.47.3