From e1f316fe7f51671c1eca8ebfacf4267b2bb0a44c Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 21 Sep 2021 22:35:41 -0400 Subject: [PATCH] coerce for multivalues keys Fixed issue where using ORM column expressions as keys in the list of dictionaries passed to :meth:`_sql.Insert.values` for "multi-valued insert" would not be processed correctly into the correct column expressions. Fixes: #7060 Change-Id: I1c4c286c33ea6eeaafba617996828f5c88ff0a1c --- doc/build/changelog/unreleased_14/7060.rst | 7 +++ lib/sqlalchemy/sql/base.py | 6 ++- lib/sqlalchemy/sql/crud.py | 27 +++++++--- test/sql/test_insert.py | 58 +++++++++++++++++++--- 4 files changed, 83 insertions(+), 15 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/7060.rst diff --git a/doc/build/changelog/unreleased_14/7060.rst b/doc/build/changelog/unreleased_14/7060.rst new file mode 100644 index 0000000000..3df13259b4 --- /dev/null +++ b/doc/build/changelog/unreleased_14/7060.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, sql + :tickets: 7060 + + Fixed issue where using ORM column expressions as keys in the list of + dictionaries passed to :meth:`_sql.Insert.values` for "multi-valued insert" + would not be processed correctly into the correct column expressions. diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index d998e8e5c7..a6870f8d4b 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -121,7 +121,11 @@ def _exclusive_against(*names, **kw): ] @util.decorator - def check(fn, self, *args, **kw): + def check(fn, *args, **kw): + # make pylance happy by not including "self" in the argument + # list + self = args[0] + args = args[1:] for name, getter, default_ in getters: if getter(self) is not default_: msg = msgs.get( diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index d43f33ebb3..a9c9cb4c13 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -119,7 +119,7 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): # special logic that only occurs for multi-table UPDATE # statements if compile_state.isupdate and compile_state.is_multitable: - _get_multitable_params( + _get_update_multitable_params( compiler, stmt, compile_state, @@ -172,7 +172,12 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): if compile_state._has_multi_parameters: values = _extend_values_for_multiparams( - compiler, stmt, compile_state, values, kw + compiler, + stmt, + compile_state, + values, + _column_as_key, + kw, ) elif ( not values @@ -842,7 +847,7 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw): return _create_update_prefetch_bind_param(compiler, col, **kw) -def _get_multitable_params( +def _get_update_multitable_params( compiler, stmt, compile_state, @@ -918,15 +923,25 @@ def _get_multitable_params( compiler.postfetch.append(c) -def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw): +def _extend_values_for_multiparams( + compiler, + stmt, + compile_state, + values, + _column_as_key, + kw, +): values_0 = values values = [values] for i, row in enumerate(compile_state._multi_parameters[1:]): extension = [] + + row = {_column_as_key(key): v for key, v in row.items()} + for (col, col_expr, param) in values_0: - if col in row or col.key in row: - key = col if col in row else col.key + if col.key in row: + key = col.key if coercions._is_literal(row[key]): new_param = _create_bind_param( diff --git a/test/sql/test_insert.py b/test/sql/test_insert.py index 6c2a5d9557..51045daac2 100644 --- a/test/sql/test_insert.py +++ b/test/sql/test_insert.py @@ -28,6 +28,14 @@ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures +class ORMExpr(object): + def __init__(self, col): + self.col = col + + def __clause_element__(self): + return self.col + + class _InsertTestBase(object): @classmethod def define_tables(cls, metadata): @@ -1126,13 +1134,33 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): dialect=dialect, ) - def test_named_with_column_objects(self): + @testing.combinations(("strings",), ("columns",), ("inspectables",)) + def test_named_with_column_objects(self, column_style): table1 = self.tables.mytable + if column_style == "strings": + myid, name, description = "myid", "name", "description" + + elif column_style == "columns": + myid, name, description = ( + table1.c.myid, + table1.c.name, + table1.c.description, + ) + elif column_style == "inspectables": + + myid, name, description = ( + ORMExpr(table1.c.myid), + ORMExpr(table1.c.name), + ORMExpr(table1.c.description), + ) + else: + assert False + values = [ - {table1.c.myid: 1, table1.c.name: "a", table1.c.description: "b"}, - {table1.c.myid: 2, table1.c.name: "c", table1.c.description: "d"}, - {table1.c.myid: 3, table1.c.name: "e", table1.c.description: "f"}, + {myid: 1, name: "a", description: "b"}, + {myid: 2, name: "c", description: "d"}, + {myid: 3, name: "e", description: "f"}, ] checkparams = { @@ -1304,7 +1332,8 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): dialect=postgresql.dialect(), ) - def test_python_scalar_default(self): + @testing.combinations(("strings",), ("columns",), ("inspectables",)) + def test_python_scalar_default(self, key_type): metadata = MetaData() table = Table( "sometable", @@ -1314,10 +1343,23 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): Column("foo", Integer, default=10), ) + if key_type == "strings": + id_, data, foo = "id", "data", "foo" + elif key_type == "columns": + id_, data, foo = table.c.id, table.c.data, table.c.foo + elif key_type == "inspectables": + id_, data, foo = ( + ORMExpr(table.c.id), + ORMExpr(table.c.data), + ORMExpr(table.c.foo), + ) + else: + assert False + values = [ - {"id": 1, "data": "data1"}, - {"id": 2, "data": "data2", "foo": 15}, - {"id": 3, "data": "data3"}, + {id_: 1, data: "data1"}, + {id_: 2, data: "data2", foo: 15}, + {id_: 3, data: "data3"}, ] checkparams = { -- 2.47.3