]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Use full column->type processing for ON CONFLICT SET clause
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 13 Jan 2017 17:43:24 +0000 (12:43 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 13 Jan 2017 19:33:42 +0000 (14:33 -0500)
Fixed bug in new "ON CONFLICT DO UPDATE" feature where the "set"
values for the UPDATE clause would not be subject to type-level
processing, as normally takes effect to handle both user-defined
type level conversions as well as dialect-required conversions, such
as those required for JSON datatypes.   Additionally, clarified that
the keys in the set_ dictionary should match the "key" of the column,
if distinct from the column name.  A warning is emitted
for remaining column names that don't match column keys; for
compatibility reasons, these are emitted as they were previously.

Fixes: #3888
Change-Id: I67a04c67aa5f65e6d29f27bf3ef2f8257088d073

doc/build/changelog/changelog_11.rst
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/dml.py
test/dialect/postgresql/test_compiler.py
test/dialect/postgresql/test_on_conflict.py

index 5e3e275d2a7e58c5fd62fb1eb1f94782559f181f..3bfa7ba1b97475824d8c30a8a5b10de85fe44161 100644 (file)
 .. changelog::
     :version: 1.1.5
 
+    .. change:: 3888
+        :tags: bug, postgresql
+        :tickets: 3888
+
+        Fixed bug in new "ON CONFLICT DO UPDATE" feature where the "set"
+        values for the UPDATE clause would not be subject to type-level
+        processing, as normally takes effect to handle both user-defined
+        type level conversions as well as dialect-required conversions, such
+        as those required for JSON datatypes.   Additionally, clarified that
+        the keys in the set_ dictionary should match the "key" of the column,
+        if distinct from the column name.  A warning is emitted
+        for remaining column names that don't match column keys; for
+        compatibility reasons, these are emitted as they were previously.
+
     .. change:: 3872
         :tags: bug, examples
         :tickets: 3872
index b436b934fb2ecaeb0ae4ec313e90da70a90b2390..169b792f5dc1ed3ac75c5735737cb6c8e0897273 100644 (file)
@@ -862,6 +862,7 @@ import re
 import datetime as dt
 
 
+from sqlalchemy.sql import elements
 from ... import sql, schema, exc, util
 from ...engine import default, reflection
 from ...sql import compiler, expression
@@ -1499,17 +1500,52 @@ class PGCompiler(compiler.SQLCompiler):
         target_text = self._on_conflict_target(on_conflict, **kw)
 
         action_set_ops = []
-        for k, v in clause.update_values_to_set:
-            key_text = (
-                self.preparer.quote(k)
-                if isinstance(k, util.string_types)
-                else self.process(k, use_schema=False)
-            )
-            value_text = self.process(
-                v,
-                use_schema=False
+
+        set_parameters = dict(clause.update_values_to_set)
+        # create a list of column assignment clauses as tuples
+        cols = self.statement.table.c
+        for c in cols:
+            col_key = c.key
+            if col_key in set_parameters:
+                value = set_parameters.pop(col_key)
+                if elements._is_literal(value):
+                    value = elements.BindParameter(
+                        None, value, type_=c.type
+                    )
+
+                else:
+                    if isinstance(value, elements.BindParameter) and \
+                            value.type._isnull:
+                        value = value._clone()
+                        value.type = c.type
+                value_text = self.process(value.self_group(), use_schema=False)
+
+                key_text = (
+                    self.preparer.quote(col_key)
+                )
+                action_set_ops.append('%s = %s' % (key_text, value_text))
+
+        # check for names that don't match columns
+        if set_parameters:
+            util.warn(
+                "Additional column names not matching "
+                "any column keys in table '%s': %s" % (
+                    self.statement.table.name,
+                    (", ".join("'%s'" % c for c in set_parameters))
+                )
             )
-            action_set_ops.append('%s = %s' % (key_text, value_text))
+            for k, v in set_parameters.items():
+                key_text = (
+                    self.preparer.quote(k)
+                    if isinstance(k, util.string_types)
+                    else self.process(k, use_schema=False)
+                )
+                value_text = self.process(
+                    elements._literal_as_binds(v),
+                    use_schema=False
+                )
+                action_set_ops.append('%s = %s' % (key_text, value_text))
+
         action_text = ', '.join(action_set_ops)
         if clause.update_whereclause is not None:
             action_text += ' WHERE %s' % \
index df53fa8a2d164328247e68ad2bcaf4a77518d21f..bfdfbfa36498ef37146b54e59c212840ae28e889 100644 (file)
@@ -70,6 +70,8 @@ class Insert(StandardInsert):
         Required argument. A dictionary or other mapping object
         with column names as keys and expressions or literals as values,
         specifying the ``SET`` actions to take.
+        If the target :class:`.Column` specifies a ".key" attribute distinct
+        from the column name, that key should be used.
 
         .. warning:: This dictionary does **not** take into account
            Python-specified default UPDATE values or generation functions,
@@ -205,7 +207,7 @@ class OnConflictDoUpdate(OnConflictClause):
         if (not isinstance(set_, dict) or not set_):
             raise ValueError("set parameter must be a non-empty dictionary")
         self.update_values_to_set = [
-            (key, _literal_as_binds(value))
+            (key, value)
             for key, value in set_.items()
         ]
         self.update_whereclause = where
index 3e7f584bf74b0ca370e0623f6148c8b7c08dd697..99706bad847ed7e4ea512b3e94b679f96af87a56 100644 (file)
@@ -1,12 +1,12 @@
 # coding: utf-8
 
 from sqlalchemy.testing.assertions import AssertsCompiledSQL, is_, \
-    assert_raises, assert_raises_message
+    assert_raises, assert_raises_message, expect_warnings
 from sqlalchemy.testing import engines, fixtures
 from sqlalchemy import testing
 from sqlalchemy import Sequence, Table, Column, Integer, update, String,\
     func, MetaData, Enum, Index, and_, delete, select, cast, text, \
-    Text
+    Text, null
 from sqlalchemy.dialects.postgresql import ExcludeConstraint, array
 from sqlalchemy import exc, schema
 from sqlalchemy.dialects import postgresql
@@ -1089,7 +1089,7 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL):
             "(%(name)s) ON CONFLICT (myid) DO NOTHING"
         )
 
-    def test_do_update_set_clause_literal(self):
+    def test_do_update_set_clause_none(self):
         i = insert(self.table_with_metadata).values(myid=1, name='foo')
         i = i.on_conflict_do_update(
             index_elements=['myid'],
@@ -1097,6 +1097,25 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL):
                 ('name', "I'm a name"),
                 ('description', None)])
         )
+        self.assert_compile(
+            i,
+            'INSERT INTO mytable (myid, name) VALUES '
+            '(%(myid)s, %(name)s) ON CONFLICT (myid) '
+            'DO UPDATE SET name = %(param_1)s, '
+            'description = %(param_2)s',
+            {"myid": 1, "name": "foo",
+             "param_1": "I'm a name", "param_2": None}
+
+        )
+
+    def test_do_update_set_clause_literal(self):
+        i = insert(self.table_with_metadata).values(myid=1, name='foo')
+        i = i.on_conflict_do_update(
+            index_elements=['myid'],
+            set_=OrderedDict([
+                ('name', "I'm a name"),
+                ('description', null())])
+        )
         self.assert_compile(
             i,
             'INSERT INTO mytable (myid, name) VALUES '
@@ -1296,6 +1315,28 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL):
                             'DO UPDATE SET name = excluded.name '
                             "WHERE mytable.name != excluded.name")
 
+    def test_do_update_additional_colnames(self):
+        i = insert(
+            self.table1, values=dict(name='bar'))
+        i = i.on_conflict_do_update(
+            constraint=self.excl_constr_anon,
+            set_=dict(name='somename', unknown='unknown')
+        )
+        with expect_warnings(
+                "Additional column names not matching any "
+                "column keys in table 'mytable': 'unknown'"):
+            self.assert_compile(i,
+                                'INSERT INTO mytable (name) VALUES '
+                                "(%(name)s) ON CONFLICT (name, description) "
+                                "WHERE description != %(description_1)s "
+                                "DO UPDATE SET name = %(param_1)s, "
+                                "unknown = %(param_2)s",
+                                checkparams={
+                                    "name": "bar",
+                                    "description_1": "foo",
+                                    "param_1": "somename",
+                                    "param_2": "unknown"})
+
     def test_quote_raw_string_col(self):
         t = table('t', column("FancyName"), column("other name"))
 
index 9cfe4432a416477b26f6e548269fbc647466acdc..0e1dea06a7268bb7a6b5fe97b0495d70a126c628 100644 (file)
@@ -4,7 +4,7 @@ from sqlalchemy.testing.assertions import eq_, assert_raises
 from sqlalchemy.testing import fixtures
 from sqlalchemy import testing
 from sqlalchemy import Table, Column, Integer, String
-from sqlalchemy import exc, schema
+from sqlalchemy import exc, schema, types as sqltypes, sql
 from sqlalchemy.dialects.postgresql import insert
 
 
@@ -21,6 +21,18 @@ class OnConflictTest(fixtures.TablesTest):
             Column('name', String(50))
         )
 
+        class SpecialType(sqltypes.TypeDecorator):
+            impl = String
+
+            def process_bind_param(self, value, dialect):
+                return value + " processed"
+
+        Table(
+            'bind_targets', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('data', SpecialType())
+        )
+
         users_xtra = Table(
             'users_xtra', metadata,
             Column('id', Integer, primary_key=True),
@@ -473,3 +485,30 @@ class OnConflictTest(fixtures.TablesTest):
                     (2, 'name2', 'name2@gmail.com', 'not')
                 ]
             )
+
+    def test_on_conflict_do_update_special_types_in_set(self):
+        bind_targets = self.tables.bind_targets
+
+        with testing.db.connect() as conn:
+            i = insert(bind_targets)
+            conn.execute(i, {"id": 1, "data": "initial data"})
+
+            eq_(
+                conn.scalar(sql.select([bind_targets.c.data])),
+                "initial data processed"
+            )
+
+            i = insert(bind_targets)
+            i = i.on_conflict_do_update(
+                index_elements=[bind_targets.c.id],
+                set_=dict(data="new updated data")
+            )
+            conn.execute(
+                i, {"id": 1, "data": "new inserted data"}
+            )
+
+            eq_(
+                conn.scalar(sql.select([bind_targets.c.data])),
+                "new updated data processed"
+            )
+