]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Render VALUES within composed MySQL on duplicate key expressions
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 4 Mar 2020 22:44:40 +0000 (17:44 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 4 Mar 2020 22:46:55 +0000 (17:46 -0500)
Fixed issue in MySQL :meth:`.mysql.Insert.on_duplicate_key_update` construct
where using a SQL function or other composed expression for a column argument
would not properly render the ``VALUES`` keyword surrounding the column
itself.

Fixes: #5173
Change-Id: I16d39c2fdb8bbb7f3d1b2ffdd20e1bf69359ab75
(cherry picked from commit 57b2aae0d9efe91c2338e5a762e04366f86c2651)

doc/build/changelog/unreleased_13/5173.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
test/dialect/mysql/test_compiler.py
test/dialect/mysql/test_on_duplicate.py

diff --git a/doc/build/changelog/unreleased_13/5173.rst b/doc/build/changelog/unreleased_13/5173.rst
new file mode 100644 (file)
index 0000000..15e4fa2
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, mysql
+    :tickets: 5173
+
+    Fixed issue in MySQL :meth:`.mysql.Insert.on_duplicate_key_update` construct
+    where using a SQL function or other composed expression for a column argument
+    would not properly render the ``VALUES`` keyword surrounding the column
+    itself.
index 5c2f114bddd5d1d16d6177894021dae971824bc6..82a5a6b2209ad9d791cb7fa63903f52b38266345 100644 (file)
@@ -743,6 +743,8 @@ from collections import defaultdict
 import re
 import sys
 
+from sqlalchemy import literal_column
+from sqlalchemy.sql import visitors
 from . import reflection as _reflection
 from .enumerated import ENUM
 from .enumerated import SET
@@ -1289,17 +1291,31 @@ class MySQLCompiler(compiler.SQLCompiler):
             if elements._is_literal(val):
                 val = elements.BindParameter(None, val, type_=column.type)
                 value_text = self.process(val.self_group(), use_schema=False)
-            elif isinstance(val, elements.BindParameter) and val.type._isnull:
-                val = val._clone()
-                val.type = column.type
-                value_text = self.process(val.self_group(), use_schema=False)
-            elif (
-                isinstance(val, elements.ColumnClause)
-                and val.table is on_duplicate.inserted_alias
-            ):
-                value_text = "VALUES(" + self.preparer.quote(column.name) + ")"
             else:
+
+                def replace(obj):
+                    if (
+                        isinstance(obj, elements.BindParameter)
+                        and obj.type._isnull
+                    ):
+                        obj = obj._clone()
+                        obj.type = column.type
+                        return obj
+                    elif (
+                        isinstance(obj, elements.ColumnClause)
+                        and obj.table is on_duplicate.inserted_alias
+                    ):
+                        obj = literal_column(
+                            "VALUES(" + self.preparer.quote(column.name) + ")"
+                        )
+                        return obj
+                    else:
+                        # element is not replaced
+                        return None
+
+                val = visitors.replacement_traverse(val, {}, replace)
                 value_text = self.process(val.self_group(), use_schema=False)
+
             name_text = self.preparer.quote(column.name)
             clauses.append("%s = %s" % (name_text, value_text))
 
index d59c0549f14c388e3f9ade90b14d71f832844aa8..e74c37d63da69d440bbcafca446eca0de9062712 100644 (file)
@@ -928,3 +928,28 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
             "ON DUPLICATE KEY UPDATE bar = %s"
         )
         self.assert_compile(stmt, expected_sql)
+
+    def test_update_sql_expr(self):
+        stmt = insert(self.table).values(
+            [{"id": 1, "bar": "ab"}, {"id": 2, "bar": "b"}]
+        )
+        stmt = stmt.on_duplicate_key_update(
+            bar=func.coalesce(stmt.inserted.bar),
+            baz=stmt.inserted.baz + "some literal",
+        )
+        expected_sql = (
+            "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) ON "
+            "DUPLICATE KEY UPDATE bar = coalesce(VALUES(bar)), "
+            "baz = (concat(VALUES(baz), %s))"
+        )
+        self.assert_compile(
+            stmt,
+            expected_sql,
+            checkparams={
+                "id_m0": 1,
+                "bar_m0": "ab",
+                "id_m1": 2,
+                "bar_m1": "b",
+                "baz_1": "some literal",
+            },
+        )
index 077e5ba98cc719ef96adebfe4f826460e51cfd11..43964815362e627a5c3f402f4a697d7d622ab308 100644 (file)
@@ -77,6 +77,23 @@ class OnDuplicateTest(fixtures.TablesTest):
                 [(1, "b", "bz", None)],
             )
 
+    def test_on_duplicate_key_update_expression(self):
+        foos = self.tables.foos
+        with testing.db.connect() as conn:
+            conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
+            stmt = insert(foos).values(
+                [dict(id=1, bar="ab"), dict(id=2, bar="b")]
+            )
+            stmt = stmt.on_duplicate_key_update(
+                bar=func.concat(stmt.inserted.bar, "_foo")
+            )
+            result = conn.execute(stmt)
+            eq_(result.inserted_primary_key, [2])
+            eq_(
+                conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
+                [(1, "ab_foo", "bz", False)],
+            )
+
     def test_on_duplicate_key_update_preserve_order(self):
         foos = self.tables.foos
         with testing.db.connect() as conn: