]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
accommodate arbitrary embedded params in insertmanyvalues
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 15 Oct 2022 19:20:21 +0000 (15:20 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 16 Oct 2022 12:47:47 +0000 (08:47 -0400)
Fixed bug in new "insertmanyvalues" feature where INSERT that included a
subquery with :func:`_sql.bindparam` inside of it would fail to render
correctly in "insertmanyvalues" format. This affected psycopg2 most
directly as "insertmanyvalues" is used unconditionally with this driver.

Fixes: #8639
Change-Id: I67903fa86afe208899d4f23f940e0727d1be2ce3

doc/build/changelog/unreleased_20/8639.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/crud.py
test/sql/test_insert_exec.py
test/sql/test_type_expressions.py

diff --git a/doc/build/changelog/unreleased_20/8639.rst b/doc/build/changelog/unreleased_20/8639.rst
new file mode 100644 (file)
index 0000000..46cce75
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, regression, sql
+    :tickets: 8639
+
+    Fixed bug in new "insertmanyvalues" feature where INSERT that included a
+    subquery with :func:`_sql.bindparam` inside of it would fail to render
+    correctly in "insertmanyvalues" format. This affected psycopg2 most
+    directly as "insertmanyvalues" is used unconditionally with this driver.
+
index dd40bfe345505a823ef5dcf1383aa3d190b55c4a..efe0ea2b439ffd2c48e856a62bf4a4069cf24032 100644 (file)
@@ -94,7 +94,6 @@ if typing.TYPE_CHECKING:
     from .elements import BindParameter
     from .elements import ColumnClause
     from .elements import ColumnElement
-    from .elements import KeyedColumnElement
     from .elements import Label
     from .functions import Function
     from .selectable import AliasedReturnsRows
@@ -236,6 +235,7 @@ BIND_TEMPLATES = {
     "named": ":%(name)s",
 }
 
+
 _BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]")
 _BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__"))
 
@@ -416,7 +416,7 @@ class _InsertManyValues(NamedTuple):
 
     is_default_expr: bool
     single_values_expr: str
-    insert_crud_params: List[Tuple[KeyedColumnElement[Any], str, str]]
+    insert_crud_params: List[crud._CrudParamElementStr]
     num_positional_params_counted: int
 
 
@@ -2960,6 +2960,7 @@ class SQLCompiler(Compiled):
         skip_bind_expression=False,
         literal_execute=False,
         render_postcompile=False,
+        accumulate_bind_names=None,
         **kwargs,
     ):
         if not skip_bind_expression:
@@ -2973,6 +2974,7 @@ class SQLCompiler(Compiled):
                     literal_binds=literal_binds,
                     literal_execute=literal_execute,
                     render_postcompile=render_postcompile,
+                    accumulate_bind_names=accumulate_bind_names,
                     **kwargs,
                 )
                 if bindparam.expanding:
@@ -3063,6 +3065,9 @@ class SQLCompiler(Compiled):
 
         self.binds[bindparam.key] = self.binds[name] = bindparam
 
+        if accumulate_bind_names is not None:
+            accumulate_bind_names.add(name)
+
         # if we are given a cache key that we're going to match against,
         # relate the bindparam here to one that is most likely present
         # in the "extracted params" portion of the cache key.  this is used
@@ -4646,13 +4651,25 @@ class SQLCompiler(Compiled):
 
             all_keys = set(parameters[0])
 
-            escaped_insert_crud_params: Sequence[Any] = [
-                (escaped_bind_names.get(col.key, col.key), formatted)
-                for col, _, formatted in insert_crud_params
-            ]
+            def apply_placeholders(keys, formatted):
+                for key in keys:
+                    key = escaped_bind_names.get(key, key)
+                    formatted = formatted.replace(
+                        self.bindtemplate % {"name": key},
+                        self.bindtemplate
+                        % {"name": f"{key}__EXECMANY_INDEX__"},
+                    )
+                return formatted
+
+            formatted_values_clause = f"""({', '.join(
+                apply_placeholders(bind_keys, formatted)
+                for _, _, formatted, bind_keys in insert_crud_params
+            )})"""
 
             keys_to_replace = all_keys.intersection(
-                key for key, _ in escaped_insert_crud_params
+                escaped_bind_names.get(key, key)
+                for _, _, _, bind_keys in insert_crud_params
+                for key in bind_keys
             )
             base_parameters = {
                 key: parameters[0][key]
@@ -4660,7 +4677,7 @@ class SQLCompiler(Compiled):
             }
             executemany_values_w_comma = ""
         else:
-            escaped_insert_crud_params = ()
+            formatted_values_clause = ""
             keys_to_replace = set()
             base_parameters = {}
             executemany_values_w_comma = f"({imv.single_values_expr}), "
@@ -4723,14 +4740,10 @@ class SQLCompiler(Compiled):
                 replaced_parameters = base_parameters.copy()
 
                 for i, param in enumerate(batch):
-                    new_tokens = [
-                        formatted.replace(key, f"{key}__{i}")
-                        if key in param
-                        else formatted
-                        for key, formatted in escaped_insert_crud_params
-                    ]
                     replaced_values_clauses.append(
-                        f"({', '.join(new_tokens)})"
+                        formatted_values_clause.replace(
+                            "EXECMANY_INDEX__", str(i)
+                        )
                     )
 
                     replaced_parameters.update(
@@ -4841,7 +4854,7 @@ class SQLCompiler(Compiled):
 
         if crud_params_single or not supports_default_values:
             text += " (%s)" % ", ".join(
-                [expr for _, expr, _ in crud_params_single]
+                [expr for _, expr, _, _ in crud_params_single]
             )
 
         if self.implicit_returning or insert_stmt._returning:
@@ -4902,8 +4915,7 @@ class SQLCompiler(Compiled):
                     True,
                     self.dialect.default_metavalue_token,
                     cast(
-                        "List[Tuple[KeyedColumnElement[Any], str, str]]",
-                        crud_params_single,
+                        "List[crud._CrudParamElementStr]", crud_params_single
                     ),
                     (positiontup_after - positiontup_before),
                 )
@@ -4911,7 +4923,7 @@ class SQLCompiler(Compiled):
             text += " VALUES %s" % (
                 ", ".join(
                     "(%s)"
-                    % (", ".join(value for _, _, value in crud_param_set))
+                    % (", ".join(value for _, _, value, _ in crud_param_set))
                     for crud_param_set in crud_params_struct.all_multi_params
                 )
             )
@@ -4921,8 +4933,9 @@ class SQLCompiler(Compiled):
             insert_single_values_expr = ", ".join(
                 [
                     value
-                    for _, _, value in cast(
-                        "List[Tuple[Any, Any, str]]", crud_params_single
+                    for _, _, value, _ in cast(
+                        "List[crud._CrudParamElementStr]",
+                        crud_params_single,
                     )
                 ]
             )
@@ -4935,7 +4948,7 @@ class SQLCompiler(Compiled):
                     False,
                     insert_single_values_expr,
                     cast(
-                        "List[Tuple[KeyedColumnElement[Any], str, str]]",
+                        "List[crud._CrudParamElementStr]",
                         crud_params_single,
                     ),
                     positiontup_after - positiontup_before,
@@ -5058,8 +5071,8 @@ class SQLCompiler(Compiled):
         text += " SET "
         text += ", ".join(
             expr + "=" + value
-            for _, expr, value in cast(
-                "List[Tuple[Any, str, str]]", crud_params
+            for _, expr, value, _ in cast(
+                "List[Tuple[Any, str, str, Any]]", crud_params
             )
         )
 
index 22fffb73a1a2e210499cf7cdfbb2ecc4765a2a84..31d127c2c34b688ba6de5881379b6d912b3badf0 100644 (file)
@@ -18,12 +18,14 @@ from typing import Any
 from typing import Callable
 from typing import cast
 from typing import Dict
+from typing import Iterable
 from typing import List
 from typing import MutableMapping
 from typing import NamedTuple
 from typing import Optional
 from typing import overload
 from typing import Sequence
+from typing import Set
 from typing import Tuple
 from typing import TYPE_CHECKING
 from typing import Union
@@ -49,6 +51,7 @@ if TYPE_CHECKING:
     from .dml import DMLState
     from .dml import ValuesBase
     from .elements import ColumnElement
+    from .elements import KeyedColumnElement
     from .schema import _SQLExprDefault
 
 REQUIRED = util.symbol(
@@ -74,18 +77,32 @@ def _as_dml_column(c: ColumnElement[Any]) -> ColumnClause[Any]:
     return c
 
 
-_CrudParamSequence = Sequence[
-    Tuple[
-        "ColumnElement[Any]",
-        str,
-        Optional[Union[str, "_SQLExprDefault"]],
-    ]
+_CrudParamElement = Tuple[
+    "ColumnElement[Any]",
+    str,
+    Optional[Union[str, "_SQLExprDefault"]],
+    Iterable[str],
+]
+_CrudParamElementStr = Tuple[
+    "KeyedColumnElement[Any]",
+    str,
+    str,
+    Iterable[str],
 ]
+_CrudParamElementSQLExpr = Tuple[
+    "ColumnClause[Any]",
+    str,
+    "_SQLExprDefault",
+    Iterable[str],
+]
+
+_CrudParamSequence = List[_CrudParamElement]
 
 
 class _CrudParams(NamedTuple):
     single_params: _CrudParamSequence
-    all_multi_params: List[Sequence[Tuple[ColumnClause[Any], str, str]]]
+
+    all_multi_params: List[Sequence[_CrudParamElementStr]]
 
 
 def _get_crud_params(
@@ -175,6 +192,7 @@ def _get_crud_params(
                     c,
                     compiler.preparer.format_column(c),
                     _create_bind_param(compiler, c, None, required=True),
+                    (c.key,),
                 )
                 for c in stmt.table.columns
             ],
@@ -220,9 +238,7 @@ def _get_crud_params(
         )
 
     # create a list of column assignment clauses as tuples
-    values: List[
-        Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]]
-    ] = []
+    values: List[_CrudParamElement] = []
 
     if stmt_parameter_tuples is not None:
         _get_stmt_parameter_tuples_params(
@@ -307,7 +323,10 @@ def _get_crud_params(
             compiler,
             stmt,
             compile_state,
-            cast("Sequence[Tuple[ColumnClause[Any], str, str]]", values),
+            cast(
+                "Sequence[_CrudParamElementStr]",
+                values,
+            ),
             cast("Callable[..., str]", _column_as_key),
             kw,
         )
@@ -326,6 +345,7 @@ def _get_crud_params(
                 _as_dml_column(stmt.table.columns[0]),
                 compiler.preparer.format_column(stmt.table.columns[0]),
                 compiler.dialect.default_metavalue_token,
+                (),
             )
         ]
 
@@ -488,7 +508,7 @@ def _scan_insert_from_select_cols(
 
     compiler.stack[-1]["insert_from_select"] = stmt.select
 
-    add_select_cols: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]] = []
+    add_select_cols: List[_CrudParamElementSQLExpr] = []
     if stmt.include_insert_from_select_defaults:
         col_set = set(cols)
         for col in stmt.table.columns:
@@ -499,7 +519,7 @@ def _scan_insert_from_select_cols(
         col_key = _getattr_col_key(c)
         if col_key in parameters and col_key not in check_columns:
             parameters.pop(col_key)
-            values.append((c, compiler.preparer.format_column(c), None))
+            values.append((c, compiler.preparer.format_column(c), None, ()))
         else:
             _append_param_insert_select_hasdefault(
                 compiler, stmt, c, add_select_cols, kw
@@ -513,7 +533,7 @@ def _scan_insert_from_select_cols(
                 f"Can't extend statement for INSERT..FROM SELECT to include "
                 f"additional default-holding column(s) "
                 f"""{
-                    ', '.join(repr(key) for _, key, _ in add_select_cols)
+                    ', '.join(repr(key) for _, key, _, _ in add_select_cols)
                 }.  Convert the selectable to a subquery() first, or pass """
                 "include_defaults=False to Insert.from_select() to skip these "
                 "columns."
@@ -521,7 +541,7 @@ def _scan_insert_from_select_cols(
         ins_from_select = ins_from_select._generate()
         # copy raw_columns
         ins_from_select._raw_columns = list(ins_from_select._raw_columns) + [
-            expr for _, _, expr in add_select_cols
+            expr for _, _, expr, _ in add_select_cols
         ]
         compiler.stack[-1]["insert_from_select"] = ins_from_select
 
@@ -748,6 +768,8 @@ def _append_param_parameter(
         c, use_table=compile_state.include_table_with_column_exprs
     )
 
+    accumulated_bind_names: Set[str] = set()
+
     if coercions._is_literal(value):
 
         if (
@@ -772,6 +794,7 @@ def _append_param_parameter(
             if not _compile_state_isinsert(compile_state)
             or not compile_state._has_multi_parameters
             else "%s_m0" % _col_bind_name(c),
+            accumulate_bind_names=accumulated_bind_names,
             **kw,
         )
     elif value._is_bind_parameter:
@@ -796,11 +819,16 @@ def _append_param_parameter(
             if not _compile_state_isinsert(compile_state)
             or not compile_state._has_multi_parameters
             else "%s_m0" % _col_bind_name(c),
+            accumulate_bind_names=accumulated_bind_names,
             **kw,
         )
     else:
         # value is a SQL expression
-        value = compiler.process(value.self_group(), **kw)
+        value = compiler.process(
+            value.self_group(),
+            accumulate_bind_names=accumulated_bind_names,
+            **kw,
+        )
 
         if compile_state.isupdate:
             if implicit_return_defaults and c in implicit_return_defaults:
@@ -828,7 +856,7 @@ def _append_param_parameter(
 
                 compiler.postfetch.append(c)
 
-    values.append((c, col_value, value))
+    values.append((c, col_value, value, accumulated_bind_names))
 
 
 def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
@@ -843,20 +871,32 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
                 not c.default.optional
                 or not compiler.dialect.sequences_optional
             ):
+                accumulated_bind_names: Set[str] = set()
                 values.append(
                     (
                         c,
                         compiler.preparer.format_column(c),
-                        compiler.process(c.default, **kw),
+                        compiler.process(
+                            c.default,
+                            accumulate_bind_names=accumulated_bind_names,
+                            **kw,
+                        ),
+                        accumulated_bind_names,
                     )
                 )
             compiler.implicit_returning.append(c)
         elif c.default.is_clause_element:
+            accumulated_bind_names = set()
             values.append(
                 (
                     c,
                     compiler.preparer.format_column(c),
-                    compiler.process(c.default.arg.self_group(), **kw),
+                    compiler.process(
+                        c.default.arg.self_group(),
+                        accumulate_bind_names=accumulated_bind_names,
+                        **kw,
+                    ),
+                    accumulated_bind_names,
                 )
             )
             compiler.implicit_returning.append(c)
@@ -869,6 +909,7 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
                     c,
                     compiler.preparer.format_column(c),
                     _create_insert_prefetch_bind_param(compiler, c, **kw),
+                    (c.key,),
                 )
             )
     elif c is stmt.table._autoincrement_column or c.server_default is not None:
@@ -936,6 +977,7 @@ def _append_param_insert_pk_no_returning(compiler, stmt, c, values, kw):
                 c,
                 compiler.preparer.format_column(c),
                 _create_insert_prefetch_bind_param(compiler, c, **kw),
+                (c.key,),
             )
         )
     elif (
@@ -962,11 +1004,17 @@ def _append_param_insert_hasdefault(
         if compiler.dialect.supports_sequences and (
             not c.default.optional or not compiler.dialect.sequences_optional
         ):
+            accumulated_bind_names: Set[str] = set()
             values.append(
                 (
                     c,
                     compiler.preparer.format_column(c),
-                    compiler.process(c.default, **kw),
+                    compiler.process(
+                        c.default,
+                        accumulate_bind_names=accumulated_bind_names,
+                        **kw,
+                    ),
+                    accumulated_bind_names,
                 )
             )
             if implicit_return_defaults and c in implicit_return_defaults:
@@ -974,11 +1022,17 @@ def _append_param_insert_hasdefault(
             elif not c.primary_key:
                 compiler.postfetch.append(c)
     elif c.default.is_clause_element:
+        accumulated_bind_names = set()
         values.append(
             (
                 c,
                 compiler.preparer.format_column(c),
-                compiler.process(c.default.arg.self_group(), **kw),
+                compiler.process(
+                    c.default.arg.self_group(),
+                    accumulate_bind_names=accumulated_bind_names,
+                    **kw,
+                ),
+                accumulated_bind_names,
             )
         )
 
@@ -993,6 +1047,7 @@ def _append_param_insert_hasdefault(
                 c,
                 compiler.preparer.format_column(c),
                 _create_insert_prefetch_bind_param(compiler, c, **kw),
+                (c.key,),
             )
         )
 
@@ -1001,7 +1056,7 @@ def _append_param_insert_select_hasdefault(
     compiler: SQLCompiler,
     stmt: ValuesBase,
     c: ColumnClause[Any],
-    values: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]],
+    values: List[_CrudParamElementSQLExpr],
     kw: Dict[str, Any],
 ) -> None:
 
@@ -1014,6 +1069,7 @@ def _append_param_insert_select_hasdefault(
                     c,
                     compiler.preparer.format_column(c),
                     c.default.next_value(),
+                    (),
                 )
             )
     elif default_is_clause_element(c.default):
@@ -1022,6 +1078,7 @@ def _append_param_insert_select_hasdefault(
                 c,
                 compiler.preparer.format_column(c),
                 c.default.arg.self_group(),
+                (),
             )
         )
     else:
@@ -1032,6 +1089,7 @@ def _append_param_insert_select_hasdefault(
                 _create_insert_prefetch_bind_param(
                     compiler, c, process=False, **kw
                 ),
+                (c.key,),
             )
         )
 
@@ -1051,6 +1109,7 @@ def _append_param_update(
                         use_table=include_table,
                     ),
                     compiler.process(c.onupdate.arg.self_group(), **kw),
+                    (),
                 )
             )
             if implicit_return_defaults and c in implicit_return_defaults:
@@ -1066,6 +1125,7 @@ def _append_param_update(
                         use_table=include_table,
                     ),
                     _create_update_prefetch_bind_param(compiler, c, **kw),
+                    (c.key,),
                 )
             )
     elif c.server_onupdate is not None:
@@ -1177,7 +1237,7 @@ class _multiparam_column(elements.ColumnElement[Any]):
 def _process_multiparam_default_bind(
     compiler: SQLCompiler,
     stmt: ValuesBase,
-    c: ColumnClause[Any],
+    c: KeyedColumnElement[Any],
     index: int,
     kw: Dict[str, Any],
 ) -> str:
@@ -1243,14 +1303,18 @@ def _get_update_multitable_params(
                         name=_col_bind_name(c),
                         **kw,  # TODO: no test coverage for literal binds here
                     )
+                    accumulated_bind_names: Iterable[str] = (c.key,)
                 elif value._is_bind_parameter:
+                    cbn = _col_bind_name(c)
                     value = _handle_values_anonymous_param(
-                        compiler, c, value, name=_col_bind_name(c), **kw
+                        compiler, c, value, name=cbn, **kw
                     )
+                    accumulated_bind_names = (cbn,)
                 else:
                     compiler.postfetch.append(c)
                     value = compiler.process(value.self_group(), **kw)
-                values.append((c, col_value, value))
+                    accumulated_bind_names = ()
+                values.append((c, col_value, value, accumulated_bind_names))
     # determine tables which are actually to be updated - process onupdate
     # and server_onupdate for these
     for t in affected_tables:
@@ -1266,6 +1330,7 @@ def _get_update_multitable_params(
                             compiler.process(
                                 c.onupdate.arg.self_group(), **kw
                             ),
+                            (),
                         )
                     )
                     compiler.postfetch.append(c)
@@ -1277,6 +1342,7 @@ def _get_update_multitable_params(
                             _create_update_prefetch_bind_param(
                                 compiler, c, name=_col_bind_name(c), **kw
                             ),
+                            (c.key,),
                         )
                     )
             elif c.server_onupdate is not None:
@@ -1287,21 +1353,21 @@ def _extend_values_for_multiparams(
     compiler: SQLCompiler,
     stmt: ValuesBase,
     compile_state: DMLState,
-    initial_values: Sequence[Tuple[ColumnClause[Any], str, str]],
+    initial_values: Sequence[_CrudParamElementStr],
     _column_as_key: Callable[..., str],
     kw: Dict[str, Any],
-) -> List[Sequence[Tuple[ColumnClause[Any], str, str]]]:
+) -> List[Sequence[_CrudParamElementStr]]:
     values_0 = initial_values
     values = [initial_values]
 
     mp = compile_state._multi_parameters
     assert mp is not None
     for i, row in enumerate(mp[1:]):
-        extension: List[Tuple[ColumnClause[Any], str, str]] = []
+        extension: List[_CrudParamElementStr] = []
 
         row = {_column_as_key(key): v for key, v in row.items()}
 
-        for (col, col_expr, param) in values_0:
+        for (col, col_expr, param, accumulated_names) in values_0:
             if col.key in row:
                 key = col.key
 
@@ -1320,7 +1386,7 @@ def _extend_values_for_multiparams(
                     compiler, stmt, col, i, kw
                 )
 
-            extension.append((col, col_expr, new_param))
+            extension.append((col, col_expr, new_param, accumulated_names))
 
         values.append(extension)
 
@@ -1366,7 +1432,8 @@ def _get_stmt_parameter_tuples_params(
 
                 v = compiler.process(v.self_group(), **kw)
 
-            values.append((k, col_expr, v))
+            # TODO: not sure if accumulated_bind_names applies here
+            values.append((k, col_expr, v, ()))
 
 
 def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
index 4ce093156d8edb1326c56a139126d3f98f192ecd..429ebf163c72abeb925025bcd450429e4682685f 100644 (file)
@@ -1,6 +1,7 @@
 import itertools
 
 from sqlalchemy import and_
+from sqlalchemy import bindparam
 from sqlalchemy import event
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
@@ -8,6 +9,7 @@ from sqlalchemy import func
 from sqlalchemy import INT
 from sqlalchemy import Integer
 from sqlalchemy import literal
+from sqlalchemy import select
 from sqlalchemy import Sequence
 from sqlalchemy import sql
 from sqlalchemy import String
@@ -741,6 +743,14 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest):
             Column("\u6e2c\u8a66", Integer),
         )
 
+        Table(
+            "extra_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("x_value", String(50)),
+            Column("y_value", String(50)),
+        )
+
     def test_insert_unicode_keys(self, connection):
         table = self.tables["Unitéble2"]
 
@@ -807,6 +817,88 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest):
 
         eq_(result.inserted_primary_key_rows, [(1,), (2,), (3,)])
 
+    @testing.combinations(True, False, argnames="use_returning")
+    @testing.combinations(1, 2, argnames="num_embedded_params")
+    @testing.combinations(True, False, argnames="use_whereclause")
+    @testing.crashes(
+        "+mariadbconnector",
+        "returning crashes, regular executemany malfunctions",
+    )
+    def test_insert_w_bindparam_in_subq(
+        self, connection, use_returning, num_embedded_params, use_whereclause
+    ):
+        """test #8639"""
+
+        t = self.tables.data
+        extra = self.tables.extra_table
+
+        conn = connection
+        connection.execute(
+            extra.insert(),
+            [
+                {"x_value": "p1", "y_value": "yv1"},
+                {"x_value": "p2", "y_value": "yv2"},
+                {"x_value": "p1_p1", "y_value": "yv3"},
+                {"x_value": "p2_p2", "y_value": "yv4"},
+            ],
+        )
+
+        if num_embedded_params == 1:
+            if use_whereclause:
+                scalar_subq = select(bindparam("paramname")).scalar_subquery()
+                params = [
+                    {"paramname": "p1_p1", "y": "y1"},
+                    {"paramname": "p2_p2", "y": "y2"},
+                ]
+            else:
+                scalar_subq = (
+                    select(extra.c.x_value)
+                    .where(extra.c.y_value == bindparam("y_value"))
+                    .scalar_subquery()
+                )
+                params = [
+                    {"y_value": "yv3", "y": "y1"},
+                    {"y_value": "yv4", "y": "y2"},
+                ]
+
+        elif num_embedded_params == 2:
+            if use_whereclause:
+                scalar_subq = (
+                    select(
+                        bindparam("paramname1", type_=String) + extra.c.x_value
+                    )
+                    .where(extra.c.y_value == bindparam("y_value"))
+                    .scalar_subquery()
+                )
+                params = [
+                    {"paramname1": "p1_", "y_value": "yv1", "y": "y1"},
+                    {"paramname1": "p2_", "y_value": "yv2", "y": "y2"},
+                ]
+            else:
+                scalar_subq = select(
+                    bindparam("paramname1", type_=String)
+                    + bindparam("paramname2", type_=String)
+                ).scalar_subquery()
+                params = [
+                    {"paramname1": "p1_", "paramname2": "p1", "y": "y1"},
+                    {"paramname1": "p2_", "paramname2": "p2", "y": "y2"},
+                ]
+        else:
+            assert False
+
+        stmt = t.insert().values(x=scalar_subq)
+        if use_returning:
+            stmt = stmt.returning(t.c["x", "y"])
+
+        result = conn.execute(stmt, params)
+
+        if use_returning:
+            eq_(result.all(), [("p1_p1", "y1"), ("p2_p2", "y2")])
+
+        result = conn.execute(select(t.c["x", "y"]))
+
+        eq_(result.all(), [("p1_p1", "y1"), ("p2_p2", "y2")])
+
     def test_insert_returning_defaults(self, connection):
         t = self.tables.data
 
index 901be7132a2765923acddc2de0e7350fc0730791..6c717123bd4fc89a873f10663ed14f05aa8fbd04 100644 (file)
@@ -365,6 +365,22 @@ class DerivedTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
 
 
 class RoundTripTestBase:
+    @testing.requires.insertmanyvalues
+    def test_insertmanyvalues_returning(self, connection):
+        tt = self.tables.test_table
+        result = connection.execute(
+            tt.insert().returning(tt.c["x", "y"]),
+            [
+                {"x": "X1", "y": "Y1"},
+                {"x": "X2", "y": "Y2"},
+                {"x": "X3", "y": "Y3"},
+            ],
+        )
+        eq_(
+            result.all(),
+            [("X1", "Y1"), ("X2", "Y2"), ("X3", "Y3")],
+        )
+
     def test_round_trip(self, connection):
         connection.execute(
             self.tables.test_table.insert(),