From: Mike Bayer Date: Tue, 24 Nov 2020 01:13:20 +0000 (-0500) Subject: Support Column objects in the SET clause for upsert X-Git-Tag: rel_1_4_0b2~132^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=584cabbf7e79948e38b29df5af63c3c712566f31;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support Column objects in the SET clause for upsert Established support for :class:`_schema.Column` objects as well as ORM instrumented attributes as keys in the ``set_`` dictionary passed to the :meth:`_postgresql.Insert.on_conflict_do_update` and :meth:`_sqlite.Insert.on_conflict_do_update` methods, which match to the :class:`_schema.Column` objects in the ``.c`` collection of the target :class:`_schema.Table`. Previously, only string column names were expected; a column expression would be assumed to be an out-of-table expression that would render fully along with a warning. Fixes: #5722 Change-Id: Ice73b501d721c28d978a0277a83cedc6aff756a9 --- diff --git a/doc/build/changelog/unreleased_14/5722.rst b/doc/build/changelog/unreleased_14/5722.rst new file mode 100644 index 0000000000..e756f8eb9c --- /dev/null +++ b/doc/build/changelog/unreleased_14/5722.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: bug, postgresql + :tickets: 5722 + :versions: 1.4.0b2 + + Established support for :class:`_schema.Column` objects as well as ORM + instrumented attributes as keys in the ``set_`` dictionary passed to the + :meth:`_postgresql.Insert.on_conflict_do_update` and + :meth:`_sqlite.Insert.on_conflict_do_update` methods, which match to the + :class:`_schema.Column` objects in the ``.c`` collection of the target + :class:`_schema.Table`. Previously, only string column names were + expected; a column expression would be assumed to be an out-of-table + expression that would render fully along with a warning. \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 3c33d9ee8e..3a458ebed4 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2147,22 +2147,28 @@ class PGCompiler(compiler.SQLCompiler): cols = insert_statement.table.c for c in cols: col_key = c.key + if col_key in set_parameters: value = set_parameters.pop(col_key) - if coercions._is_literal(value): - value = elements.BindParameter(None, value, type_=c.type) + elif c in set_parameters: + value = set_parameters.pop(c) + else: + continue - else: - if ( - isinstance(value, elements.BindParameter) - and value.type._isnull - ): - value = value._clone() - value.type = c.type - value_text = self.process(value.self_group(), use_schema=False) - - key_text = self.preparer.quote(col_key) - action_set_ops.append("%s = %s" % (key_text, value_text)) + if coercions._is_literal(value): + value = elements.BindParameter(None, value, type_=c.type) + + else: + if ( + isinstance(value, elements.BindParameter) + and value.type._isnull + ): + value = value._clone() + value.type = c.type + value_text = self.process(value.self_group(), use_schema=False) + + key_text = self.preparer.quote(col_key) + action_set_ops.append("%s = %s" % (key_text, value_text)) # check for names that don't match columns if set_parameters: diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py index 50fd095287..78cad974fc 100644 --- a/lib/sqlalchemy/dialects/postgresql/dml.py +++ b/lib/sqlalchemy/dialects/postgresql/dml.py @@ -7,6 +7,8 @@ from . import ext from ... import util +from ...sql import coercions +from ...sql import roles from ...sql import schema from ...sql.base import _generative from ...sql.dml import Insert as StandardInsert @@ -77,12 +79,16 @@ class Insert(StandardInsert): conditional target index. :param set\_: - Required argument. A dictionary or other mapping object - with column names as keys and expressions or literals as values, - specifying the ``SET`` actions to take. - If the target :class:`_schema.Column` specifies a ". - key" attribute distinct - from the column name, that key should be used. + A dictionary or other mapping object + where the keys are either names of columns in the target table, + or :class:`_schema.Column` objects or other ORM-mapped columns + matching that of the target table, and expressions or literals + as values, specifying the ``SET`` actions to take. + + .. versionadded:: 1.4 The + :paramref:`_postgresql.Insert.on_conflict_do_update.set_` + parameter supports :class:`_schema.Column` objects from the target + :class:`_schema.Table` as keys. .. warning:: This dictionary does **not** take into account Python-specified default UPDATE values or generation functions, @@ -229,6 +235,7 @@ class OnConflictDoUpdate(OnConflictClause): if not isinstance(set_, dict) or not set_: raise ValueError("set parameter must be a non-empty dictionary") self.update_values_to_set = [ - (key, value) for key, value in set_.items() + (coercions.expect(roles.DMLColumnRole, key), value) + for key, value in set_.items() ] self.update_whereclause = where diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 404a215b62..7c1bbb18ef 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1355,22 +1355,28 @@ class SQLiteCompiler(compiler.SQLCompiler): cols = insert_statement.table.c for c in cols: col_key = c.key + if col_key in set_parameters: value = set_parameters.pop(col_key) - if coercions._is_literal(value): - value = elements.BindParameter(None, value, type_=c.type) + elif c in set_parameters: + value = set_parameters.pop(c) + else: + continue - else: - if ( - isinstance(value, elements.BindParameter) - and value.type._isnull - ): - value = value._clone() - value.type = c.type - value_text = self.process(value.self_group(), use_schema=False) - - key_text = self.preparer.quote(col_key) - action_set_ops.append("%s = %s" % (key_text, value_text)) + if coercions._is_literal(value): + value = elements.BindParameter(None, value, type_=c.type) + + else: + if ( + isinstance(value, elements.BindParameter) + and value.type._isnull + ): + value = value._clone() + value.type = c.type + value_text = self.process(value.self_group(), use_schema=False) + + key_text = self.preparer.quote(col_key) + action_set_ops.append("%s = %s" % (key_text, value_text)) # check for names that don't match columns if set_parameters: diff --git a/lib/sqlalchemy/dialects/sqlite/dml.py b/lib/sqlalchemy/dialects/sqlite/dml.py index a4d4d560cf..2d7ea6e4a8 100644 --- a/lib/sqlalchemy/dialects/sqlite/dml.py +++ b/lib/sqlalchemy/dialects/sqlite/dml.py @@ -5,6 +5,8 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php from ... import util +from ...sql import coercions +from ...sql import roles from ...sql.base import _generative from ...sql.dml import Insert as StandardInsert from ...sql.elements import ClauseElement @@ -65,12 +67,16 @@ class Insert(StandardInsert): conditional target index. :param set\_: - Required argument. A dictionary or other mapping object - with column names as keys and expressions or literals as values, - specifying the ``SET`` actions to take. - If the target :class:`_schema.Column` specifies a ". - key" attribute distinct - from the column name, that key should be used. + A dictionary or other mapping object + where the keys are either names of columns in the target table, + or :class:`_schema.Column` objects or other ORM-mapped columns + matching that of the target table, and expressions or literals + as values, specifying the ``SET`` actions to take. + + .. versionadded:: 1.4 The + :paramref:`_sqlite.Insert.on_conflict_do_update.set_` + parameter supports :class:`_schema.Column` objects from the target + :class:`_schema.Table` as keys. .. warning:: This dictionary does **not** take into account Python-specified default UPDATE values or generation functions, @@ -155,6 +161,7 @@ class OnConflictDoUpdate(OnConflictClause): if not isinstance(set_, dict) or not set_: raise ValueError("set parameter must be a non-empty dictionary") self.update_values_to_set = [ - (key, value) for key, value in set_.items() + (coercions.expect(roles.DMLColumnRole, key), value) + for key, value in set_.items() ] self.update_whereclause = where diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index a031c3df93..9651f7bd9d 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -1864,6 +1864,31 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL): }, ) + def test_do_update_set_clause_column_keys(self): + i = insert(self.table_with_metadata).values(myid=1, name="foo") + i = i.on_conflict_do_update( + index_elements=["myid"], + set_=OrderedDict( + [ + (self.table_with_metadata.c.name, "I'm a name"), + (self.table_with_metadata.c.description, None), + ] + ), + ) + self.assert_compile( + i, + "INSERT INTO mytable (myid, name) VALUES " + "(%(myid)s, %(name)s) ON CONFLICT (myid) " + "DO UPDATE SET name = %(param_1)s, " + "description = %(param_2)s", + { + "myid": 1, + "name": "foo", + "param_1": "I'm a name", + "param_2": None, + }, + ) + def test_do_update_set_clause_literal(self): i = insert(self.table_with_metadata).values(myid=1, name="foo") i = i.on_conflict_do_update( diff --git a/test/dialect/postgresql/test_on_conflict.py b/test/dialect/postgresql/test_on_conflict.py index 7a9bfd75da..7604878426 100644 --- a/test/dialect/postgresql/test_on_conflict.py +++ b/test/dialect/postgresql/test_on_conflict.py @@ -10,6 +10,7 @@ from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import types as sqltypes from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.testing import config from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import assert_raises from sqlalchemy.testing.assertions import eq_ @@ -30,6 +31,14 @@ class OnConflictTest(fixtures.TablesTest): Column("name", String(50)), ) + Table( + "users_schema", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + schema=config.test_schema, + ) + class SpecialType(sqltypes.TypeDecorator): impl = String @@ -185,6 +194,99 @@ class OnConflictTest(fixtures.TablesTest): [(1, "name1")], ) + def test_on_conflict_do_update_schema(self): + users = self.tables.get("%s.users_schema" % config.test_schema) + + with testing.db.connect() as conn: + conn.execute(users.insert(), dict(id=1, name="name1")) + + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.id], set_=dict(name=i.excluded.name) + ) + result = conn.execute(i, dict(id=1, name="name1")) + + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) + + eq_( + conn.execute(users.select().where(users.c.id == 1)).fetchall(), + [(1, "name1")], + ) + + def test_on_conflict_do_update_column_as_key_set(self): + users = self.tables.users + + with testing.db.connect() as conn: + conn.execute(users.insert(), dict(id=1, name="name1")) + + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.id], + set_={users.c.name: i.excluded.name}, + ) + result = conn.execute(i, dict(id=1, name="name1")) + + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) + + eq_( + conn.execute(users.select().where(users.c.id == 1)).fetchall(), + [(1, "name1")], + ) + + def test_on_conflict_do_update_clauseelem_as_key_set(self): + users = self.tables.users + + class MyElem(object): + def __init__(self, expr): + self.expr = expr + + def __clause_element__(self): + return self.expr + + with testing.db.connect() as conn: + conn.execute( + users.insert(), + {"id": 1, "name": "name1"}, + ) + + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.id], + set_={MyElem(users.c.name): i.excluded.name}, + ).values({MyElem(users.c.id): 1, MyElem(users.c.name): "name1"}) + result = conn.execute(i) + + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) + + eq_( + conn.execute(users.select().where(users.c.id == 1)).fetchall(), + [(1, "name1")], + ) + + def test_on_conflict_do_update_column_as_key_set_schema(self): + users = self.tables.get("%s.users_schema" % config.test_schema) + + with testing.db.connect() as conn: + conn.execute(users.insert(), dict(id=1, name="name1")) + + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.id], + set_={users.c.name: i.excluded.name}, + ) + result = conn.execute(i, dict(id=1, name="name1")) + + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) + + eq_( + conn.execute(users.select().where(users.c.id == 1)).fetchall(), + [(1, "name1")], + ) + def test_on_conflict_do_update_two(self): users = self.tables.users diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index 456bad7bdf..f8b50f8883 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -2921,6 +2921,53 @@ class OnConflictTest(fixtures.TablesTest): [(10, "I'm a name")], ) + def test_on_conflict_do_update_column_keys(self, connection): + users = self.tables.users + + conn = connection + conn.execute(users.insert(), dict(id=1, name="name1")) + + i = insert(users) + i = i.on_conflict_do_update( + index_elements=users.primary_key.columns, + set_={users.c.id: 10, users.c.name: "I'm a name"}, + ).values(id=1, name="name4") + + result = conn.execute(i) + eq_(result.inserted_primary_key, (1,)) + + eq_( + conn.execute(users.select().where(users.c.id == 10)).fetchall(), + [(10, "I'm a name")], + ) + + def test_on_conflict_do_update_clauseelem_keys(self, connection): + users = self.tables.users + + class MyElem(object): + def __init__(self, expr): + self.expr = expr + + def __clause_element__(self): + return self.expr + + conn = connection + conn.execute(users.insert(), dict(id=1, name="name1")) + + i = insert(users) + i = i.on_conflict_do_update( + index_elements=users.primary_key.columns, + set_={MyElem(users.c.id): 10, MyElem(users.c.name): "I'm a name"}, + ).values({MyElem(users.c.id): 1, MyElem(users.c.name): "name4"}) + + result = conn.execute(i) + eq_(result.inserted_primary_key, (1,)) + + eq_( + conn.execute(users.select().where(users.c.id == 10)).fetchall(), + [(10, "I'm a name")], + ) + def test_on_conflict_do_update_multivalues(self, connection): users = self.tables.users