From 7d9a2ae3057f0707832699b617c0e44e6d7780ec Mon Sep 17 00:00:00 2001 From: Mingyu Park Date: Fri, 31 Jan 2025 22:09:25 +0900 Subject: [PATCH] Fixes: #12117 Unable to use InstrumentedAttribute to value mappings in mysql/mariadb on_duplicate_key_update (https://github.com/sqlalchemy/sqlalchemy/issues/12117) --- lib/sqlalchemy/dialects/mysql/base.py | 12 +++++++++--- test/dialect/mysql/test_compiler.py | 17 +++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 71a4a4b666..79c3d0982f 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1097,6 +1097,7 @@ from ...engine import cursor as _cursor from ...engine import default from ...engine import reflection from ...engine.reflection import ReflectionDefaults +from ...orm.attributes import InstrumentedAttribute from ...sql import coercions from ...sql import compiler from ...sql import elements @@ -1401,9 +1402,14 @@ class MySQLCompiler(compiler.SQLCompiler): else: _on_dup_alias_name = "new" + on_duplicate_update = { + (key.key if isinstance(key, InstrumentedAttribute) else key): value + for key, value in on_duplicate.update.items() + } + # traverses through all table columns to preserve table column order - for column in (col for col in cols if col.key in on_duplicate.update): - val = on_duplicate.update[column.key] + for column in (col for col in cols if col.key in on_duplicate_update): + val = on_duplicate_update[column.key] # TODO: this coercion should be up front. we can't cache # SQL constructs with non-bound literals buried in them @@ -1444,7 +1450,7 @@ class MySQLCompiler(compiler.SQLCompiler): name_text = self.preparer.quote(column.name) clauses.append("%s = %s" % (name_text, value_text)) - non_matching = set(on_duplicate.update) - {c.key for c in cols} + non_matching = set(on_duplicate_update) - {c.key for c in cols} if non_matching: util.warn( "Additional column names not matching " diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 59d604eace..b2eb6033dc 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -54,6 +54,9 @@ from sqlalchemy import VARCHAR from sqlalchemy.dialects.mysql import base as mysql from sqlalchemy.dialects.mysql import insert from sqlalchemy.dialects.mysql import match +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.sql import column from sqlalchemy.sql import delete from sqlalchemy.sql import table @@ -1343,7 +1346,21 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL): }, dialect=dialect, ) + + def test_on_update_instrumented_attribute_dict(self): + class Base(DeclarativeBase): + pass + class T(Base): + __tablename__ = "table" + foo: Mapped[int] = mapped_column(Integer, primary_key=True) + + q = insert(T).values(foo=1).on_duplicate_key_update({T.foo: 2}) + self.assert_compile( + q, + "INSERT INTO `table` (foo) VALUES (%s) ON DUPLICATE KEY UPDATE foo = %s", + {"foo": 1, "param_1": 2} + ) class RegexpCommon(testing.AssertsCompiledSQL): def setup_test(self): -- 2.47.3