]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
coerce for multivalues keys
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 22 Sep 2021 02:35:41 +0000 (22:35 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 22 Sep 2021 02:36:27 +0000 (22:36 -0400)
Fixed issue where using ORM column expressions as keys in the list of
dictionaries passed to :meth:`_sql.Insert.values` for "multi-valued insert"
would not be processed correctly into the correct column expressions.

Fixes: #7060
Change-Id: I1c4c286c33ea6eeaafba617996828f5c88ff0a1c

doc/build/changelog/unreleased_14/7060.rst [new file with mode: 0644]
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/crud.py
test/sql/test_insert.py

diff --git a/doc/build/changelog/unreleased_14/7060.rst b/doc/build/changelog/unreleased_14/7060.rst
new file mode 100644 (file)
index 0000000..3df1325
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 7060
+
+    Fixed issue where using ORM column expressions as keys in the list of
+    dictionaries passed to :meth:`_sql.Insert.values` for "multi-valued insert"
+    would not be processed correctly into the correct column expressions.
index d998e8e5c70dd0711056b56d6dca40398342eea5..a6870f8d4bf4f2648da8112324732bb3dbbc182a 100644 (file)
@@ -121,7 +121,11 @@ def _exclusive_against(*names, **kw):
     ]
 
     @util.decorator
-    def check(fn, self, *args, **kw):
+    def check(fn, *args, **kw):
+        # make pylance happy by not including "self" in the argument
+        # list
+        self = args[0]
+        args = args[1:]
         for name, getter, default_ in getters:
             if getter(self) is not default_:
                 msg = msgs.get(
index d43f33ebb38674a9b4eb3054384427647d77adad..a9c9cb4c133344ddafbae2480f41b16c0725fdaa 100644 (file)
@@ -119,7 +119,7 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
     # special logic that only occurs for multi-table UPDATE
     # statements
     if compile_state.isupdate and compile_state.is_multitable:
-        _get_multitable_params(
+        _get_update_multitable_params(
             compiler,
             stmt,
             compile_state,
@@ -172,7 +172,12 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
 
     if compile_state._has_multi_parameters:
         values = _extend_values_for_multiparams(
-            compiler, stmt, compile_state, values, kw
+            compiler,
+            stmt,
+            compile_state,
+            values,
+            _column_as_key,
+            kw,
         )
     elif (
         not values
@@ -842,7 +847,7 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
             return _create_update_prefetch_bind_param(compiler, col, **kw)
 
 
-def _get_multitable_params(
+def _get_update_multitable_params(
     compiler,
     stmt,
     compile_state,
@@ -918,15 +923,25 @@ def _get_multitable_params(
                 compiler.postfetch.append(c)
 
 
-def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw):
+def _extend_values_for_multiparams(
+    compiler,
+    stmt,
+    compile_state,
+    values,
+    _column_as_key,
+    kw,
+):
     values_0 = values
     values = [values]
 
     for i, row in enumerate(compile_state._multi_parameters[1:]):
         extension = []
+
+        row = {_column_as_key(key): v for key, v in row.items()}
+
         for (col, col_expr, param) in values_0:
-            if col in row or col.key in row:
-                key = col if col in row else col.key
+            if col.key in row:
+                key = col.key
 
                 if coercions._is_literal(row[key]):
                     new_param = _create_bind_param(
index 6c2a5d9557a53991c47e32752647c9adad28cfd5..51045daac223bf4d78f42c5ef007977029733372 100644 (file)
@@ -28,6 +28,14 @@ from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 
 
+class ORMExpr(object):
+    def __init__(self, col):
+        self.col = col
+
+    def __clause_element__(self):
+        return self.col
+
+
 class _InsertTestBase(object):
     @classmethod
     def define_tables(cls, metadata):
@@ -1126,13 +1134,33 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             dialect=dialect,
         )
 
-    def test_named_with_column_objects(self):
+    @testing.combinations(("strings",), ("columns",), ("inspectables",))
+    def test_named_with_column_objects(self, column_style):
         table1 = self.tables.mytable
 
+        if column_style == "strings":
+            myid, name, description = "myid", "name", "description"
+
+        elif column_style == "columns":
+            myid, name, description = (
+                table1.c.myid,
+                table1.c.name,
+                table1.c.description,
+            )
+        elif column_style == "inspectables":
+
+            myid, name, description = (
+                ORMExpr(table1.c.myid),
+                ORMExpr(table1.c.name),
+                ORMExpr(table1.c.description),
+            )
+        else:
+            assert False
+
         values = [
-            {table1.c.myid: 1, table1.c.name: "a", table1.c.description: "b"},
-            {table1.c.myid: 2, table1.c.name: "c", table1.c.description: "d"},
-            {table1.c.myid: 3, table1.c.name: "e", table1.c.description: "f"},
+            {myid: 1, name: "a", description: "b"},
+            {myid: 2, name: "c", description: "d"},
+            {myid: 3, name: "e", description: "f"},
         ]
 
         checkparams = {
@@ -1304,7 +1332,8 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             dialect=postgresql.dialect(),
         )
 
-    def test_python_scalar_default(self):
+    @testing.combinations(("strings",), ("columns",), ("inspectables",))
+    def test_python_scalar_default(self, key_type):
         metadata = MetaData()
         table = Table(
             "sometable",
@@ -1314,10 +1343,23 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             Column("foo", Integer, default=10),
         )
 
+        if key_type == "strings":
+            id_, data, foo = "id", "data", "foo"
+        elif key_type == "columns":
+            id_, data, foo = table.c.id, table.c.data, table.c.foo
+        elif key_type == "inspectables":
+            id_, data, foo = (
+                ORMExpr(table.c.id),
+                ORMExpr(table.c.data),
+                ORMExpr(table.c.foo),
+            )
+        else:
+            assert False
+
         values = [
-            {"id": 1, "data": "data1"},
-            {"id": 2, "data": "data2", "foo": 15},
-            {"id": 3, "data": "data3"},
+            {id_: 1, data: "data1"},
+            {id_: 2, data: "data2", foo: 15},
+            {id_: 3, data: "data3"},
         ]
 
         checkparams = {