]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Unable to use InstrumentedAttribute to value mappings in mysql/mariadb on_duplicate_k...
authorMingyu Park <mingyuu.dev@gmail.com>
Sat, 1 Feb 2025 07:43:35 +0000 (02:43 -0500)
committerFederico Caselli <cfederico87@gmail.com>
Sat, 1 Feb 2025 18:15:47 +0000 (19:15 +0100)
Fixes: #12117
Closes: #12296
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12296
Pull-request-sha: 32a09ebd18a6f97fdb23cc8a8e212342e6c26291

Change-Id: I72701f63b13105e5dc36e63ba2651da2673f1735

doc/build/changelog/unreleased_20/12117.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
test/dialect/mysql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_20/12117.rst b/doc/build/changelog/unreleased_20/12117.rst
new file mode 100644 (file)
index 0000000..b4da4db
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, dml, mariadb, mysql
+    :tickets: 12117
+
+    Fixed a bug where the :class:`MySQLCompiler` would not properly compile statements
+    where :meth:`_mysql.Insert.on_duplicate_key_update` was passed values that included
+    :class:`InstrumentedAttribute` as keys.
+    Pull request courtesy of mingyu.
index 71a4a4b6666b30797ef77179ff68611db2a1f4f6..96eecc2ba67fa86a2d4d13e66503043475821e05 100644 (file)
@@ -1401,9 +1401,14 @@ class MySQLCompiler(compiler.SQLCompiler):
             else:
                 _on_dup_alias_name = "new"
 
+        on_duplicate_update = {
+            coercions.expect_as_key(roles.DMLColumnRole, key): value
+            for key, value in on_duplicate.update.items()
+        }
+
         # traverses through all table columns to preserve table column order
-        for column in (col for col in cols if col.key in on_duplicate.update):
-            val = on_duplicate.update[column.key]
+        for column in (col for col in cols if col.key in on_duplicate_update):
+            val = on_duplicate_update[column.key]
 
             # TODO: this coercion should be up front.  we can't cache
             # SQL constructs with non-bound literals buried in them
@@ -1444,7 +1449,7 @@ class MySQLCompiler(compiler.SQLCompiler):
             name_text = self.preparer.quote(column.name)
             clauses.append("%s = %s" % (name_text, value_text))
 
-        non_matching = set(on_duplicate.update) - {c.key for c in cols}
+        non_matching = set(on_duplicate_update) - {c.key for c in cols}
         if non_matching:
             util.warn(
                 "Additional column names not matching "
index 59d604eace1cded3b109fb1ef6887b36fb373ef7..8387d4e07c67ef7bf551642f789469ec62cfeb1f 100644 (file)
@@ -54,6 +54,9 @@ from sqlalchemy import VARCHAR
 from sqlalchemy.dialects.mysql import base as mysql
 from sqlalchemy.dialects.mysql import insert
 from sqlalchemy.dialects.mysql import match
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.sql import column
 from sqlalchemy.sql import delete
 from sqlalchemy.sql import table
@@ -1344,6 +1347,25 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect=dialect,
         )
 
+    def test_on_update_instrumented_attribute_dict(self):
+        class Base(DeclarativeBase):
+            pass
+
+        class T(Base):
+            __tablename__ = "table"
+
+            foo: Mapped[int] = mapped_column(Integer, primary_key=True)
+
+        q = insert(T).values(foo=1).on_duplicate_key_update({T.foo: 2})
+        self.assert_compile(
+            q,
+            (
+                "INSERT INTO `table` (foo) VALUES (%s) "
+                "ON DUPLICATE KEY UPDATE foo = %s"
+            ),
+            {"foo": 1, "param_1": 2},
+        )
+
 
 class RegexpCommon(testing.AssertsCompiledSQL):
     def setup_test(self):