]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixes: #12117 Unable to use InstrumentedAttribute to value mappings in mysql/mariadb...
authorMingyu Park <mingyuu.dev@gmail.com>
Fri, 31 Jan 2025 13:09:25 +0000 (22:09 +0900)
committerMingyu Park <mingyuu.dev@gmail.com>
Fri, 31 Jan 2025 13:09:25 +0000 (22:09 +0900)
lib/sqlalchemy/dialects/mysql/base.py
test/dialect/mysql/test_compiler.py

index 71a4a4b6666b30797ef77179ff68611db2a1f4f6..79c3d0982fbb9d82d135fb845b0e0acf9892cc75 100644 (file)
@@ -1097,6 +1097,7 @@ from ...engine import cursor as _cursor
 from ...engine import default
 from ...engine import reflection
 from ...engine.reflection import ReflectionDefaults
+from ...orm.attributes import InstrumentedAttribute
 from ...sql import coercions
 from ...sql import compiler
 from ...sql import elements
@@ -1401,9 +1402,14 @@ class MySQLCompiler(compiler.SQLCompiler):
             else:
                 _on_dup_alias_name = "new"
 
+        on_duplicate_update = {
+            (key.key if isinstance(key, InstrumentedAttribute) else key): value
+            for key, value in on_duplicate.update.items()
+        }
+        
         # traverses through all table columns to preserve table column order
-        for column in (col for col in cols if col.key in on_duplicate.update):
-            val = on_duplicate.update[column.key]
+        for column in (col for col in cols if col.key in on_duplicate_update):
+            val = on_duplicate_update[column.key]
 
             # TODO: this coercion should be up front.  we can't cache
             # SQL constructs with non-bound literals buried in them
@@ -1444,7 +1450,7 @@ class MySQLCompiler(compiler.SQLCompiler):
             name_text = self.preparer.quote(column.name)
             clauses.append("%s = %s" % (name_text, value_text))
 
-        non_matching = set(on_duplicate.update) - {c.key for c in cols}
+        non_matching = set(on_duplicate_update) - {c.key for c in cols}
         if non_matching:
             util.warn(
                 "Additional column names not matching "
index 59d604eace1cded3b109fb1ef6887b36fb373ef7..b2eb6033dc855d6685a955036b373d5a170e4abe 100644 (file)
@@ -54,6 +54,9 @@ from sqlalchemy import VARCHAR
 from sqlalchemy.dialects.mysql import base as mysql
 from sqlalchemy.dialects.mysql import insert
 from sqlalchemy.dialects.mysql import match
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.sql import column
 from sqlalchemy.sql import delete
 from sqlalchemy.sql import table
@@ -1343,7 +1346,21 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
             },
             dialect=dialect,
         )
+    
+    def test_on_update_instrumented_attribute_dict(self):
+        class Base(DeclarativeBase):
+            pass
+        class T(Base):
+            __tablename__ = "table"
 
+            foo: Mapped[int] = mapped_column(Integer, primary_key=True)
+
+        q = insert(T).values(foo=1).on_duplicate_key_update({T.foo: 2})
+        self.assert_compile(
+            q, 
+            "INSERT INTO `table` (foo) VALUES (%s) ON DUPLICATE KEY UPDATE foo = %s",
+            {"foo": 1, "param_1": 2}
+        )
 
 class RegexpCommon(testing.AssertsCompiledSQL):
     def setup_test(self):