]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support Column objects in the SET clause for upsert
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Nov 2020 01:13:20 +0000 (20:13 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Nov 2020 01:13:20 +0000 (20:13 -0500)
Established support for :class:`_schema.Column` objects as well as ORM
instrumented attributes as keys in the ``set_`` dictionary passed to the
:meth:`_postgresql.Insert.on_conflict_do_update` and
:meth:`_sqlite.Insert.on_conflict_do_update` methods, which match to the
:class:`_schema.Column` objects in the ``.c`` collection of the target
:class:`_schema.Table`. Previously,  only string column names were
expected; a column expression would be assumed to be an out-of-table
expression that would render fully along with a warning.

Fixes: #5722
Change-Id: Ice73b501d721c28d978a0277a83cedc6aff756a9

doc/build/changelog/unreleased_14/5722.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/dml.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/dialects/sqlite/dml.py
test/dialect/postgresql/test_compiler.py
test/dialect/postgresql/test_on_conflict.py
test/dialect/test_sqlite.py

diff --git a/doc/build/changelog/unreleased_14/5722.rst b/doc/build/changelog/unreleased_14/5722.rst
new file mode 100644 (file)
index 0000000..e756f8e
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 5722
+    :versions: 1.4.0b2
+
+    Established support for :class:`_schema.Column` objects as well as ORM
+    instrumented attributes as keys in the ``set_`` dictionary passed to the
+    :meth:`_postgresql.Insert.on_conflict_do_update` and
+    :meth:`_sqlite.Insert.on_conflict_do_update` methods, which match to the
+    :class:`_schema.Column` objects in the ``.c`` collection of the target
+    :class:`_schema.Table`. Previously,  only string column names were
+    expected; a column expression would be assumed to be an out-of-table
+    expression that would render fully along with a warning.
\ No newline at end of file
index 3c33d9ee8ec5964765eb6ec7cd4bc6e97273afa7..3a458ebed471598b6bae7b8c39d9bddc183a82d1 100644 (file)
@@ -2147,22 +2147,28 @@ class PGCompiler(compiler.SQLCompiler):
         cols = insert_statement.table.c
         for c in cols:
             col_key = c.key
+
             if col_key in set_parameters:
                 value = set_parameters.pop(col_key)
-                if coercions._is_literal(value):
-                    value = elements.BindParameter(None, value, type_=c.type)
+            elif c in set_parameters:
+                value = set_parameters.pop(c)
+            else:
+                continue
 
-                else:
-                    if (
-                        isinstance(value, elements.BindParameter)
-                        and value.type._isnull
-                    ):
-                        value = value._clone()
-                        value.type = c.type
-                value_text = self.process(value.self_group(), use_schema=False)
-
-                key_text = self.preparer.quote(col_key)
-                action_set_ops.append("%s = %s" % (key_text, value_text))
+            if coercions._is_literal(value):
+                value = elements.BindParameter(None, value, type_=c.type)
+
+            else:
+                if (
+                    isinstance(value, elements.BindParameter)
+                    and value.type._isnull
+                ):
+                    value = value._clone()
+                    value.type = c.type
+            value_text = self.process(value.self_group(), use_schema=False)
+
+            key_text = self.preparer.quote(col_key)
+            action_set_ops.append("%s = %s" % (key_text, value_text))
 
         # check for names that don't match columns
         if set_parameters:
index 50fd095287a0f86910c7faf93ed231c0cc95ac3a..78cad974fc55f82f77f580f8a62b8bc81b47debe 100644 (file)
@@ -7,6 +7,8 @@
 
 from . import ext
 from ... import util
+from ...sql import coercions
+from ...sql import roles
 from ...sql import schema
 from ...sql.base import _generative
 from ...sql.dml import Insert as StandardInsert
@@ -77,12 +79,16 @@ class Insert(StandardInsert):
          conditional target index.
 
         :param set\_:
-         Required argument. A dictionary or other mapping object
-         with column names as keys and expressions or literals as values,
-         specifying the ``SET`` actions to take.
-         If the target :class:`_schema.Column` specifies a ".
-         key" attribute distinct
-         from the column name, that key should be used.
+         A dictionary or other mapping object
+         where the keys are either names of columns in the target table,
+         or :class:`_schema.Column` objects or other ORM-mapped columns
+         matching that of the target table, and expressions or literals
+         as values, specifying the ``SET`` actions to take.
+
+         .. versionadded:: 1.4 The
+            :paramref:`_postgresql.Insert.on_conflict_do_update.set_`
+            parameter supports :class:`_schema.Column` objects from the target
+            :class:`_schema.Table` as keys.
 
          .. warning:: This dictionary does **not** take into account
             Python-specified default UPDATE values or generation functions,
@@ -229,6 +235,7 @@ class OnConflictDoUpdate(OnConflictClause):
         if not isinstance(set_, dict) or not set_:
             raise ValueError("set parameter must be a non-empty dictionary")
         self.update_values_to_set = [
-            (key, value) for key, value in set_.items()
+            (coercions.expect(roles.DMLColumnRole, key), value)
+            for key, value in set_.items()
         ]
         self.update_whereclause = where
index 404a215b629df9de88651b3be603a2fe6a4ebef2..7c1bbb18ef26b1f194919ff0671db426e5f26b73 100644 (file)
@@ -1355,22 +1355,28 @@ class SQLiteCompiler(compiler.SQLCompiler):
         cols = insert_statement.table.c
         for c in cols:
             col_key = c.key
+
             if col_key in set_parameters:
                 value = set_parameters.pop(col_key)
-                if coercions._is_literal(value):
-                    value = elements.BindParameter(None, value, type_=c.type)
+            elif c in set_parameters:
+                value = set_parameters.pop(c)
+            else:
+                continue
 
-                else:
-                    if (
-                        isinstance(value, elements.BindParameter)
-                        and value.type._isnull
-                    ):
-                        value = value._clone()
-                        value.type = c.type
-                value_text = self.process(value.self_group(), use_schema=False)
-
-                key_text = self.preparer.quote(col_key)
-                action_set_ops.append("%s = %s" % (key_text, value_text))
+            if coercions._is_literal(value):
+                value = elements.BindParameter(None, value, type_=c.type)
+
+            else:
+                if (
+                    isinstance(value, elements.BindParameter)
+                    and value.type._isnull
+                ):
+                    value = value._clone()
+                    value.type = c.type
+            value_text = self.process(value.self_group(), use_schema=False)
+
+            key_text = self.preparer.quote(col_key)
+            action_set_ops.append("%s = %s" % (key_text, value_text))
 
         # check for names that don't match columns
         if set_parameters:
index a4d4d560cf67565df08a978e8be5269fe315d138..2d7ea6e4a89b3645a0224478089fc8465bff2f88 100644 (file)
@@ -5,6 +5,8 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 from ... import util
+from ...sql import coercions
+from ...sql import roles
 from ...sql.base import _generative
 from ...sql.dml import Insert as StandardInsert
 from ...sql.elements import ClauseElement
@@ -65,12 +67,16 @@ class Insert(StandardInsert):
          conditional target index.
 
         :param set\_:
-         Required argument. A dictionary or other mapping object
-         with column names as keys and expressions or literals as values,
-         specifying the ``SET`` actions to take.
-         If the target :class:`_schema.Column` specifies a ".
-         key" attribute distinct
-         from the column name, that key should be used.
+         A dictionary or other mapping object
+         where the keys are either names of columns in the target table,
+         or :class:`_schema.Column` objects or other ORM-mapped columns
+         matching that of the target table, and expressions or literals
+         as values, specifying the ``SET`` actions to take.
+
+         .. versionadded:: 1.4 The
+            :paramref:`_sqlite.Insert.on_conflict_do_update.set_`
+            parameter supports :class:`_schema.Column` objects from the target
+            :class:`_schema.Table` as keys.
 
          .. warning:: This dictionary does **not** take into account
             Python-specified default UPDATE values or generation functions,
@@ -155,6 +161,7 @@ class OnConflictDoUpdate(OnConflictClause):
         if not isinstance(set_, dict) or not set_:
             raise ValueError("set parameter must be a non-empty dictionary")
         self.update_values_to_set = [
-            (key, value) for key, value in set_.items()
+            (coercions.expect(roles.DMLColumnRole, key), value)
+            for key, value in set_.items()
         ]
         self.update_whereclause = where
index a031c3df93ea417f681ddc6c62dbb83b56fbb1d5..9651f7bd9dee51b494acbd3d68cc850f3fa3b8d9 100644 (file)
@@ -1864,6 +1864,31 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL):
             },
         )
 
+    def test_do_update_set_clause_column_keys(self):
+        i = insert(self.table_with_metadata).values(myid=1, name="foo")
+        i = i.on_conflict_do_update(
+            index_elements=["myid"],
+            set_=OrderedDict(
+                [
+                    (self.table_with_metadata.c.name, "I'm a name"),
+                    (self.table_with_metadata.c.description, None),
+                ]
+            ),
+        )
+        self.assert_compile(
+            i,
+            "INSERT INTO mytable (myid, name) VALUES "
+            "(%(myid)s, %(name)s) ON CONFLICT (myid) "
+            "DO UPDATE SET name = %(param_1)s, "
+            "description = %(param_2)s",
+            {
+                "myid": 1,
+                "name": "foo",
+                "param_1": "I'm a name",
+                "param_2": None,
+            },
+        )
+
     def test_do_update_set_clause_literal(self):
         i = insert(self.table_with_metadata).values(myid=1, name="foo")
         i = i.on_conflict_do_update(
index 7a9bfd75dacba3a045ae5b71655bb0a32ae085bb..76048784264004f1ac993e5c5977d6310a5c61cd 100644 (file)
@@ -10,6 +10,7 @@ from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import types as sqltypes
 from sqlalchemy.dialects.postgresql import insert
+from sqlalchemy.testing import config
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.assertions import assert_raises
 from sqlalchemy.testing.assertions import eq_
@@ -30,6 +31,14 @@ class OnConflictTest(fixtures.TablesTest):
             Column("name", String(50)),
         )
 
+        Table(
+            "users_schema",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50)),
+            schema=config.test_schema,
+        )
+
         class SpecialType(sqltypes.TypeDecorator):
             impl = String
 
@@ -185,6 +194,99 @@ class OnConflictTest(fixtures.TablesTest):
                 [(1, "name1")],
             )
 
+    def test_on_conflict_do_update_schema(self):
+        users = self.tables.get("%s.users_schema" % config.test_schema)
+
+        with testing.db.connect() as conn:
+            conn.execute(users.insert(), dict(id=1, name="name1"))
+
+            i = insert(users)
+            i = i.on_conflict_do_update(
+                index_elements=[users.c.id], set_=dict(name=i.excluded.name)
+            )
+            result = conn.execute(i, dict(id=1, name="name1"))
+
+            eq_(result.inserted_primary_key, (1,))
+            eq_(result.returned_defaults, None)
+
+            eq_(
+                conn.execute(users.select().where(users.c.id == 1)).fetchall(),
+                [(1, "name1")],
+            )
+
+    def test_on_conflict_do_update_column_as_key_set(self):
+        users = self.tables.users
+
+        with testing.db.connect() as conn:
+            conn.execute(users.insert(), dict(id=1, name="name1"))
+
+            i = insert(users)
+            i = i.on_conflict_do_update(
+                index_elements=[users.c.id],
+                set_={users.c.name: i.excluded.name},
+            )
+            result = conn.execute(i, dict(id=1, name="name1"))
+
+            eq_(result.inserted_primary_key, (1,))
+            eq_(result.returned_defaults, None)
+
+            eq_(
+                conn.execute(users.select().where(users.c.id == 1)).fetchall(),
+                [(1, "name1")],
+            )
+
+    def test_on_conflict_do_update_clauseelem_as_key_set(self):
+        users = self.tables.users
+
+        class MyElem(object):
+            def __init__(self, expr):
+                self.expr = expr
+
+            def __clause_element__(self):
+                return self.expr
+
+        with testing.db.connect() as conn:
+            conn.execute(
+                users.insert(),
+                {"id": 1, "name": "name1"},
+            )
+
+            i = insert(users)
+            i = i.on_conflict_do_update(
+                index_elements=[users.c.id],
+                set_={MyElem(users.c.name): i.excluded.name},
+            ).values({MyElem(users.c.id): 1, MyElem(users.c.name): "name1"})
+            result = conn.execute(i)
+
+            eq_(result.inserted_primary_key, (1,))
+            eq_(result.returned_defaults, None)
+
+            eq_(
+                conn.execute(users.select().where(users.c.id == 1)).fetchall(),
+                [(1, "name1")],
+            )
+
+    def test_on_conflict_do_update_column_as_key_set_schema(self):
+        users = self.tables.get("%s.users_schema" % config.test_schema)
+
+        with testing.db.connect() as conn:
+            conn.execute(users.insert(), dict(id=1, name="name1"))
+
+            i = insert(users)
+            i = i.on_conflict_do_update(
+                index_elements=[users.c.id],
+                set_={users.c.name: i.excluded.name},
+            )
+            result = conn.execute(i, dict(id=1, name="name1"))
+
+            eq_(result.inserted_primary_key, (1,))
+            eq_(result.returned_defaults, None)
+
+            eq_(
+                conn.execute(users.select().where(users.c.id == 1)).fetchall(),
+                [(1, "name1")],
+            )
+
     def test_on_conflict_do_update_two(self):
         users = self.tables.users
 
index 456bad7bdfe9e5a66975aa30ec50bc3243d345db..f8b50f8883979e8f32088372028ba97e8b6aaa53 100644 (file)
@@ -2921,6 +2921,53 @@ class OnConflictTest(fixtures.TablesTest):
             [(10, "I'm a name")],
         )
 
+    def test_on_conflict_do_update_column_keys(self, connection):
+        users = self.tables.users
+
+        conn = connection
+        conn.execute(users.insert(), dict(id=1, name="name1"))
+
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=users.primary_key.columns,
+            set_={users.c.id: 10, users.c.name: "I'm a name"},
+        ).values(id=1, name="name4")
+
+        result = conn.execute(i)
+        eq_(result.inserted_primary_key, (1,))
+
+        eq_(
+            conn.execute(users.select().where(users.c.id == 10)).fetchall(),
+            [(10, "I'm a name")],
+        )
+
+    def test_on_conflict_do_update_clauseelem_keys(self, connection):
+        users = self.tables.users
+
+        class MyElem(object):
+            def __init__(self, expr):
+                self.expr = expr
+
+            def __clause_element__(self):
+                return self.expr
+
+        conn = connection
+        conn.execute(users.insert(), dict(id=1, name="name1"))
+
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=users.primary_key.columns,
+            set_={MyElem(users.c.id): 10, MyElem(users.c.name): "I'm a name"},
+        ).values({MyElem(users.c.id): 1, MyElem(users.c.name): "name4"})
+
+        result = conn.execute(i)
+        eq_(result.inserted_primary_key, (1,))
+
+        eq_(
+            conn.execute(users.select().where(users.c.id == 10)).fetchall(),
+            [(10, "I'm a name")],
+        )
+
     def test_on_conflict_do_update_multivalues(self, connection):
         users = self.tables.users