From afd78a37dafe8e84e23bccfb570bd758797e2142 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 13 Jan 2017 12:43:24 -0500 Subject: [PATCH] Use full column->type processing for ON CONFLICT SET clause Fixed bug in new "ON CONFLICT DO UPDATE" feature where the "set" values for the UPDATE clause would not be subject to type-level processing, as normally takes effect to handle both user-defined type level conversions as well as dialect-required conversions, such as those required for JSON datatypes. Additionally, clarified that the keys in the set_ dictionary should match the "key" of the column, if distinct from the column name. A warning is emitted for remaining column names that don't match column keys; for compatibility reasons, these are emitted as they were previously. Fixes: #3888 Change-Id: I67a04c67aa5f65e6d29f27bf3ef2f8257088d073 --- doc/build/changelog/changelog_11.rst | 14 ++++++ lib/sqlalchemy/dialects/postgresql/base.py | 56 +++++++++++++++++---- lib/sqlalchemy/dialects/postgresql/dml.py | 4 +- test/dialect/postgresql/test_compiler.py | 47 +++++++++++++++-- test/dialect/postgresql/test_on_conflict.py | 41 ++++++++++++++- 5 files changed, 147 insertions(+), 15 deletions(-) diff --git a/doc/build/changelog/changelog_11.rst b/doc/build/changelog/changelog_11.rst index 5e3e275d2a..3bfa7ba1b9 100644 --- a/doc/build/changelog/changelog_11.rst +++ b/doc/build/changelog/changelog_11.rst @@ -21,6 +21,20 @@ .. changelog:: :version: 1.1.5 + .. change:: 3888 + :tags: bug, postgresql + :tickets: 3888 + + Fixed bug in new "ON CONFLICT DO UPDATE" feature where the "set" + values for the UPDATE clause would not be subject to type-level + processing, as normally takes effect to handle both user-defined + type level conversions as well as dialect-required conversions, such + as those required for JSON datatypes. Additionally, clarified that + the keys in the set_ dictionary should match the "key" of the column, + if distinct from the column name. A warning is emitted + for remaining column names that don't match column keys; for + compatibility reasons, these are emitted as they were previously. + .. change:: 3872 :tags: bug, examples :tickets: 3872 diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index b436b934fb..169b792f5d 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -862,6 +862,7 @@ import re import datetime as dt +from sqlalchemy.sql import elements from ... import sql, schema, exc, util from ...engine import default, reflection from ...sql import compiler, expression @@ -1499,17 +1500,52 @@ class PGCompiler(compiler.SQLCompiler): target_text = self._on_conflict_target(on_conflict, **kw) action_set_ops = [] - for k, v in clause.update_values_to_set: - key_text = ( - self.preparer.quote(k) - if isinstance(k, util.string_types) - else self.process(k, use_schema=False) - ) - value_text = self.process( - v, - use_schema=False + + set_parameters = dict(clause.update_values_to_set) + # create a list of column assignment clauses as tuples + cols = self.statement.table.c + for c in cols: + col_key = c.key + if col_key in set_parameters: + value = set_parameters.pop(col_key) + if elements._is_literal(value): + value = elements.BindParameter( + None, value, type_=c.type + ) + + else: + if isinstance(value, elements.BindParameter) and \ + value.type._isnull: + value = value._clone() + value.type = c.type + value_text = self.process(value.self_group(), use_schema=False) + + key_text = ( + self.preparer.quote(col_key) + ) + action_set_ops.append('%s = %s' % (key_text, value_text)) + + # check for names that don't match columns + if set_parameters: + util.warn( + "Additional column names not matching " + "any column keys in table '%s': %s" % ( + self.statement.table.name, + (", ".join("'%s'" % c for c in set_parameters)) + ) ) - action_set_ops.append('%s = %s' % (key_text, value_text)) + for k, v in set_parameters.items(): + key_text = ( + self.preparer.quote(k) + if isinstance(k, util.string_types) + else self.process(k, use_schema=False) + ) + value_text = self.process( + elements._literal_as_binds(v), + use_schema=False + ) + action_set_ops.append('%s = %s' % (key_text, value_text)) + action_text = ', '.join(action_set_ops) if clause.update_whereclause is not None: action_text += ' WHERE %s' % \ diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py index df53fa8a2d..bfdfbfa364 100644 --- a/lib/sqlalchemy/dialects/postgresql/dml.py +++ b/lib/sqlalchemy/dialects/postgresql/dml.py @@ -70,6 +70,8 @@ class Insert(StandardInsert): Required argument. A dictionary or other mapping object with column names as keys and expressions or literals as values, specifying the ``SET`` actions to take. + If the target :class:`.Column` specifies a ".key" attribute distinct + from the column name, that key should be used. .. warning:: This dictionary does **not** take into account Python-specified default UPDATE values or generation functions, @@ -205,7 +207,7 @@ class OnConflictDoUpdate(OnConflictClause): if (not isinstance(set_, dict) or not set_): raise ValueError("set parameter must be a non-empty dictionary") self.update_values_to_set = [ - (key, _literal_as_binds(value)) + (key, value) for key, value in set_.items() ] self.update_whereclause = where diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 3e7f584bf7..99706bad84 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -1,12 +1,12 @@ # coding: utf-8 from sqlalchemy.testing.assertions import AssertsCompiledSQL, is_, \ - assert_raises, assert_raises_message + assert_raises, assert_raises_message, expect_warnings from sqlalchemy.testing import engines, fixtures from sqlalchemy import testing from sqlalchemy import Sequence, Table, Column, Integer, update, String,\ func, MetaData, Enum, Index, and_, delete, select, cast, text, \ - Text + Text, null from sqlalchemy.dialects.postgresql import ExcludeConstraint, array from sqlalchemy import exc, schema from sqlalchemy.dialects import postgresql @@ -1089,7 +1089,7 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL): "(%(name)s) ON CONFLICT (myid) DO NOTHING" ) - def test_do_update_set_clause_literal(self): + def test_do_update_set_clause_none(self): i = insert(self.table_with_metadata).values(myid=1, name='foo') i = i.on_conflict_do_update( index_elements=['myid'], @@ -1097,6 +1097,25 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL): ('name', "I'm a name"), ('description', None)]) ) + self.assert_compile( + i, + 'INSERT INTO mytable (myid, name) VALUES ' + '(%(myid)s, %(name)s) ON CONFLICT (myid) ' + 'DO UPDATE SET name = %(param_1)s, ' + 'description = %(param_2)s', + {"myid": 1, "name": "foo", + "param_1": "I'm a name", "param_2": None} + + ) + + def test_do_update_set_clause_literal(self): + i = insert(self.table_with_metadata).values(myid=1, name='foo') + i = i.on_conflict_do_update( + index_elements=['myid'], + set_=OrderedDict([ + ('name', "I'm a name"), + ('description', null())]) + ) self.assert_compile( i, 'INSERT INTO mytable (myid, name) VALUES ' @@ -1296,6 +1315,28 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL): 'DO UPDATE SET name = excluded.name ' "WHERE mytable.name != excluded.name") + def test_do_update_additional_colnames(self): + i = insert( + self.table1, values=dict(name='bar')) + i = i.on_conflict_do_update( + constraint=self.excl_constr_anon, + set_=dict(name='somename', unknown='unknown') + ) + with expect_warnings( + "Additional column names not matching any " + "column keys in table 'mytable': 'unknown'"): + self.assert_compile(i, + 'INSERT INTO mytable (name) VALUES ' + "(%(name)s) ON CONFLICT (name, description) " + "WHERE description != %(description_1)s " + "DO UPDATE SET name = %(param_1)s, " + "unknown = %(param_2)s", + checkparams={ + "name": "bar", + "description_1": "foo", + "param_1": "somename", + "param_2": "unknown"}) + def test_quote_raw_string_col(self): t = table('t', column("FancyName"), column("other name")) diff --git a/test/dialect/postgresql/test_on_conflict.py b/test/dialect/postgresql/test_on_conflict.py index 9cfe4432a4..0e1dea06a7 100644 --- a/test/dialect/postgresql/test_on_conflict.py +++ b/test/dialect/postgresql/test_on_conflict.py @@ -4,7 +4,7 @@ from sqlalchemy.testing.assertions import eq_, assert_raises from sqlalchemy.testing import fixtures from sqlalchemy import testing from sqlalchemy import Table, Column, Integer, String -from sqlalchemy import exc, schema +from sqlalchemy import exc, schema, types as sqltypes, sql from sqlalchemy.dialects.postgresql import insert @@ -21,6 +21,18 @@ class OnConflictTest(fixtures.TablesTest): Column('name', String(50)) ) + class SpecialType(sqltypes.TypeDecorator): + impl = String + + def process_bind_param(self, value, dialect): + return value + " processed" + + Table( + 'bind_targets', metadata, + Column('id', Integer, primary_key=True), + Column('data', SpecialType()) + ) + users_xtra = Table( 'users_xtra', metadata, Column('id', Integer, primary_key=True), @@ -473,3 +485,30 @@ class OnConflictTest(fixtures.TablesTest): (2, 'name2', 'name2@gmail.com', 'not') ] ) + + def test_on_conflict_do_update_special_types_in_set(self): + bind_targets = self.tables.bind_targets + + with testing.db.connect() as conn: + i = insert(bind_targets) + conn.execute(i, {"id": 1, "data": "initial data"}) + + eq_( + conn.scalar(sql.select([bind_targets.c.data])), + "initial data processed" + ) + + i = insert(bind_targets) + i = i.on_conflict_do_update( + index_elements=[bind_targets.c.id], + set_=dict(data="new updated data") + ) + conn.execute( + i, {"id": 1, "data": "new inserted data"} + ) + + eq_( + conn.scalar(sql.select([bind_targets.c.data])), + "new updated data processed" + ) + -- 2.47.2