]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Accept ColumnCollection in update_on_conflict(set_=
authorGord Thompson <gord@gordthompson.com>
Sat, 13 Feb 2021 21:43:21 +0000 (14:43 -0700)
committerGord Thompson <gord@gordthompson.com>
Mon, 15 Feb 2021 18:16:38 +0000 (11:16 -0700)
Fixes: #5939
Change-Id: I21d7125765028e2a98d5ef4c32d8e7e457aa2d12

doc/build/changelog/unreleased_14/5939.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/dml.py
lib/sqlalchemy/dialects/postgresql/dml.py
lib/sqlalchemy/dialects/sqlite/dml.py
lib/sqlalchemy/sql/functions.py
test/dialect/postgresql/test_on_conflict.py
test/dialect/test_sqlite.py

diff --git a/doc/build/changelog/unreleased_14/5939.rst b/doc/build/changelog/unreleased_14/5939.rst
new file mode 100644 (file)
index 0000000..2552cb2
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: sql, usecase, postgresql, sqlite
+    :tickets: 5939
+
+    Enhance ``set_`` keyword of :class:`.OnConflictDoUpdate` to accept a
+    :class:`.ColumnCollection`, such as the ``.c.`` collection from a
+    :class:`Selectable`, or the ``.excluded`` contextual object.
index 6c50dcca9166ba8bf0d24435c5a5e77f3036d1b4..d57a8909087cb9ea0507d0ad2f1af6d26796fd1d 100644 (file)
@@ -2,6 +2,7 @@ from ... import exc
 from ... import util
 from ...sql.base import _exclusive_against
 from ...sql.base import _generative
+from ...sql.base import ColumnCollection
 from ...sql.dml import Insert as StandardInsert
 from ...sql.elements import ClauseElement
 from ...sql.expression import alias
@@ -145,6 +146,17 @@ class OnDuplicateClause(ClauseElement):
             self._parameter_ordering = [key for key, value in update]
             update = dict(update)
 
-        if not update or not isinstance(update, dict):
-            raise ValueError("update parameter must be a non-empty dictionary")
+        if isinstance(update, dict):
+            if not update:
+                raise ValueError(
+                    "update parameter dictionary must not be empty"
+                )
+        elif isinstance(update, ColumnCollection):
+            update = dict(update)
+        else:
+            raise ValueError(
+                "update parameter must be a non-empty dictionary "
+                "or a ColumnCollection such as the `.c.` collection "
+                "of a Table object"
+            )
         self.update = update
index bff61e173674f4e388ac743ba9c7b74527bb99c1..b6f5cdf7e04156a5488a46feef110a983fefcabc 100644 (file)
@@ -12,6 +12,7 @@ from ...sql import roles
 from ...sql import schema
 from ...sql.base import _exclusive_against
 from ...sql.base import _generative
+from ...sql.base import ColumnCollection
 from ...sql.dml import Insert as StandardInsert
 from ...sql.elements import ClauseElement
 from ...sql.expression import alias
@@ -243,8 +244,17 @@ class OnConflictDoUpdate(OnConflictClause):
                 "but not both, must be specified unless DO NOTHING"
             )
 
-        if not isinstance(set_, dict) or not set_:
-            raise ValueError("set parameter must be a non-empty dictionary")
+        if isinstance(set_, dict):
+            if not set_:
+                raise ValueError("set parameter dictionary must not be empty")
+        elif isinstance(set_, ColumnCollection):
+            set_ = dict(set_)
+        else:
+            raise ValueError(
+                "set parameter must be a non-empty dictionary "
+                "or a ColumnCollection such as the `.c.` collection "
+                "of a Table object"
+            )
         self.update_values_to_set = [
             (coercions.expect(roles.DMLColumnRole, key), value)
             for key, value in set_.items()
index be32781c7a643f983a5f9e470805ce522cd6838c..4cb819960ad8b42a8ebadd62a635e4c4e480989b 100644 (file)
@@ -9,6 +9,7 @@ from ...sql import coercions
 from ...sql import roles
 from ...sql.base import _exclusive_against
 from ...sql.base import _generative
+from ...sql.base import ColumnCollection
 from ...sql.dml import Insert as StandardInsert
 from ...sql.elements import ClauseElement
 from ...sql.expression import alias
@@ -169,8 +170,17 @@ class OnConflictDoUpdate(OnConflictClause):
             index_where=index_where,
         )
 
-        if not isinstance(set_, dict) or not set_:
-            raise ValueError("set parameter must be a non-empty dictionary")
+        if isinstance(set_, dict):
+            if not set_:
+                raise ValueError("set parameter dictionary must not be empty")
+        elif isinstance(set_, ColumnCollection):
+            set_ = dict(set_)
+        else:
+            raise ValueError(
+                "set parameter must be a non-empty dictionary "
+                "or a ColumnCollection such as the `.c.` collection "
+                "of a Table object"
+            )
         self.update_values_to_set = [
             (coercions.expect(roles.DMLColumnRole, key), value)
             for key, value in set_.items()
index 641715327bff66a7cf91efeb58271bfcac22e759..40af73d7a26653a29f442d3e6fa662224d67c7d0 100644 (file)
@@ -163,7 +163,7 @@ class FunctionElement(Executable, ColumnElement, FromClause, Generative):
         return ScalarFunctionColumn(self, name, type_)
 
     def table_valued(self, *expr, **kw):
-        """Return a :class:`_sql.TableValuedAlias` representation of this
+        r"""Return a :class:`_sql.TableValuedAlias` representation of this
         :class:`_functions.FunctionElement` with table-valued expressions added.
 
         e.g.::
index 4e96cc6a217af1e67cb8b2b6b693d3c8d57dce55..489084de79afaf56ebe6827f9f8856d1368e0e7c 100644 (file)
@@ -176,14 +176,21 @@ class OnConflictTest(fixtures.TablesTest):
             [(1, "name1")],
         )
 
-    def test_on_conflict_do_update_one(self, connection):
+    @testing.combinations(
+        ("with_dict", True),
+        ("issue_5939", False),
+        id_="ia",
+        argnames="with_dict",
+    )
+    def test_on_conflict_do_update_one(self, connection, with_dict):
         users = self.tables.users
 
         connection.execute(users.insert(), dict(id=1, name="name1"))
 
         i = insert(users)
         i = i.on_conflict_do_update(
-            index_elements=[users.c.id], set_=dict(name=i.excluded.name)
+            index_elements=[users.c.id],
+            set_=dict(name=i.excluded.name) if with_dict else i.excluded,
         )
         result = connection.execute(i, dict(id=1, name="name1"))
 
index ad169eebf9bf5c7717f6b0752c2baf5643cad8d6..aee97e8c62315d1fb3a808e0a26857b9620d4564 100644 (file)
@@ -2810,7 +2810,13 @@ class OnConflictTest(fixtures.TablesTest):
             [(1, "name1")],
         )
 
-    def test_on_conflict_do_update_one(self, connection):
+    @testing.combinations(
+        ("with_dict", True),
+        ("issue_5939", False),
+        id_="ia",
+        argnames="with_dict",
+    )
+    def test_on_conflict_do_update_one(self, connection, with_dict):
         users = self.tables.users
 
         conn = connection
@@ -2818,7 +2824,8 @@ class OnConflictTest(fixtures.TablesTest):
 
         i = insert(users)
         i = i.on_conflict_do_update(
-            index_elements=[users.c.id], set_=dict(name=i.excluded.name)
+            index_elements=[users.c.id],
+            set_=dict(name=i.excluded.name) if with_dict else i.excluded,
         )
         result = conn.execute(i, dict(id=1, name="name1"))