]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fix for: unable to set column to NULL within ON DUPLICATE KEY UPDATE 4716/head
authorLukas Banic <luko@lingea.cz>
Mon, 10 Jun 2019 13:11:41 +0000 (15:11 +0200)
committerLukas Banic <luko@lingea.cz>
Mon, 10 Jun 2019 13:11:41 +0000 (15:11 +0200)
lib/sqlalchemy/dialects/mysql/base.py
test/dialect/mysql/test_on_duplicate.py

index ad5ab288ce254f9ced7921d737805cdc115de145..fe9d882f7508e29a91943503b10458936fcde490 100644 (file)
@@ -1237,10 +1237,12 @@ class MySQLCompiler(compiler.SQLCompiler):
 
         clauses = []
         for column in cols:
-            val = on_duplicate.update.get(column.key)
-            if val is None:
+            try:
+                val = on_duplicate.update[column.key]
+            except KeyError:
                 continue
-            elif coercions._is_literal(val):
+
+            if coercions._is_literal(val):
                 val = elements.BindParameter(None, val, type_=column.type)
                 value_text = self.process(val.self_group(), use_schema=False)
             elif isinstance(val, elements.BindParameter) and val.type._isnull:
index 0c6f4792901ab60dae4ba8da183479d6dafd41de..077e5ba98cc719ef96adebfe4f826460e51cfd11 100644 (file)
@@ -62,6 +62,21 @@ class OnDuplicateTest(fixtures.TablesTest):
                 [(1, "ab", "bz", False)],
             )
 
+    def test_on_duplicate_key_update_null(self):
+        foos = self.tables.foos
+        with testing.db.connect() as conn:
+            conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
+            stmt = insert(foos).values(
+                [dict(id=1, bar="ab"), dict(id=2, bar="b")]
+            )
+            stmt = stmt.on_duplicate_key_update(updated_once=None)
+            result = conn.execute(stmt)
+            eq_(result.inserted_primary_key, [2])
+            eq_(
+                conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
+                [(1, "b", "bz", None)],
+            )
+
     def test_on_duplicate_key_update_preserve_order(self):
         foos = self.tables.foos
         with testing.db.connect() as conn: