From 57b2aae0d9efe91c2338e5a762e04366f86c2651 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 4 Mar 2020 17:44:40 -0500 Subject: [PATCH] Render VALUES within composed MySQL on duplicate key expressions Fixed issue in MySQL :meth:`.mysql.Insert.on_duplicate_key_update` construct where using a SQL function or other composed expression for a column argument would not properly render the ``VALUES`` keyword surrounding the column itself. Fixes: #5173 Change-Id: I16d39c2fdb8bbb7f3d1b2ffdd20e1bf69359ab75 --- doc/build/changelog/unreleased_13/5173.rst | 8 +++++ lib/sqlalchemy/dialects/mysql/base.py | 34 ++++++++++++++++------ test/dialect/mysql/test_compiler.py | 25 ++++++++++++++++ test/dialect/mysql/test_on_duplicate.py | 17 +++++++++++ 4 files changed, 75 insertions(+), 9 deletions(-) create mode 100644 doc/build/changelog/unreleased_13/5173.rst diff --git a/doc/build/changelog/unreleased_13/5173.rst b/doc/build/changelog/unreleased_13/5173.rst new file mode 100644 index 0000000000..15e4fa2601 --- /dev/null +++ b/doc/build/changelog/unreleased_13/5173.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, mysql + :tickets: 5173 + + Fixed issue in MySQL :meth:`.mysql.Insert.on_duplicate_key_update` construct + where using a SQL function or other composed expression for a column argument + would not properly render the ``VALUES`` keyword surrounding the column + itself. diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 6ea8cbcb81..a977195321 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -743,6 +743,8 @@ from collections import defaultdict import re import sys +from sqlalchemy import literal_column +from sqlalchemy.sql import visitors from . import reflection as _reflection from .enumerated import ENUM from .enumerated import SET @@ -1303,17 +1305,31 @@ class MySQLCompiler(compiler.SQLCompiler): if coercions._is_literal(val): val = elements.BindParameter(None, val, type_=column.type) value_text = self.process(val.self_group(), use_schema=False) - elif isinstance(val, elements.BindParameter) and val.type._isnull: - val = val._clone() - val.type = column.type - value_text = self.process(val.self_group(), use_schema=False) - elif ( - isinstance(val, elements.ColumnClause) - and val.table is on_duplicate.inserted_alias - ): - value_text = "VALUES(" + self.preparer.quote(column.name) + ")" else: + + def replace(obj): + if ( + isinstance(obj, elements.BindParameter) + and obj.type._isnull + ): + obj = obj._clone() + obj.type = column.type + return obj + elif ( + isinstance(obj, elements.ColumnClause) + and obj.table is on_duplicate.inserted_alias + ): + obj = literal_column( + "VALUES(" + self.preparer.quote(column.name) + ")" + ) + return obj + else: + # element is not replaced + return None + + val = visitors.replacement_traverse(val, {}, replace) value_text = self.process(val.self_group(), use_schema=False) + name_text = self.preparer.quote(column.name) clauses.append("%s = %s" % (name_text, value_text)) diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index d59c0549f1..e74c37d63d 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -928,3 +928,28 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL): "ON DUPLICATE KEY UPDATE bar = %s" ) self.assert_compile(stmt, expected_sql) + + def test_update_sql_expr(self): + stmt = insert(self.table).values( + [{"id": 1, "bar": "ab"}, {"id": 2, "bar": "b"}] + ) + stmt = stmt.on_duplicate_key_update( + bar=func.coalesce(stmt.inserted.bar), + baz=stmt.inserted.baz + "some literal", + ) + expected_sql = ( + "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) ON " + "DUPLICATE KEY UPDATE bar = coalesce(VALUES(bar)), " + "baz = (concat(VALUES(baz), %s))" + ) + self.assert_compile( + stmt, + expected_sql, + checkparams={ + "id_m0": 1, + "bar_m0": "ab", + "id_m1": 2, + "bar_m1": "b", + "baz_1": "some literal", + }, + ) diff --git a/test/dialect/mysql/test_on_duplicate.py b/test/dialect/mysql/test_on_duplicate.py index 077e5ba98c..4396481536 100644 --- a/test/dialect/mysql/test_on_duplicate.py +++ b/test/dialect/mysql/test_on_duplicate.py @@ -77,6 +77,23 @@ class OnDuplicateTest(fixtures.TablesTest): [(1, "b", "bz", None)], ) + def test_on_duplicate_key_update_expression(self): + foos = self.tables.foos + with testing.db.connect() as conn: + conn.execute(insert(foos, dict(id=1, bar="b", baz="bz"))) + stmt = insert(foos).values( + [dict(id=1, bar="ab"), dict(id=2, bar="b")] + ) + stmt = stmt.on_duplicate_key_update( + bar=func.concat(stmt.inserted.bar, "_foo") + ) + result = conn.execute(stmt) + eq_(result.inserted_primary_key, [2]) + eq_( + conn.execute(foos.select().where(foos.c.id == 1)).fetchall(), + [(1, "ab_foo", "bz", False)], + ) + def test_on_duplicate_key_update_preserve_order(self): foos = self.tables.foos with testing.db.connect() as conn: -- 2.39.5