self.positiontup = list(param_pos)
if self.escaped_bind_names:
- reverse_escape = {v: k for k, v in self.escaped_bind_names.items()}
- assert len(self.escaped_bind_names) == len(reverse_escape)
+ len_before = len(param_pos)
param_pos = {
self.escaped_bind_names.get(name, name): pos
for name, pos in param_pos.items()
}
+ assert len(param_pos) == len_before
# Can't use format here since % chars are not escaped.
self.string = self._pyformat_pattern.sub(
skip_bind_expression=False,
literal_execute=False,
render_postcompile=False,
- accumulate_bind_names=None,
**kwargs,
):
if not skip_bind_expression:
literal_binds=literal_binds and not bindparam.expanding,
literal_execute=literal_execute,
render_postcompile=render_postcompile,
- accumulate_bind_names=accumulate_bind_names,
**kwargs,
)
if bindparam.expanding:
self.binds[bindparam.key] = self.binds[name] = bindparam
- if accumulate_bind_names is not None:
- accumulate_bind_names.add(name)
-
# if we are given a cache key that we're going to match against,
# relate the bindparam here to one that is most likely present
# in the "extracted params" portion of the cache key. this is used
expanding: bool = False,
escaped_from: Optional[str] = None,
bindparam_type: Optional[TypeEngine[Any]] = None,
+ accumulate_bind_names: Optional[Set[str]] = None,
+ visited_bindparam: Optional[List[str]] = None,
**kw: Any,
) -> str:
- if self._visited_bindparam is not None:
- self._visited_bindparam.append(name)
+ # TODO: accumulate_bind_names is passed by crud.py to gather
+ # names on a per-value basis, visited_bindparam is passed by
+ # visit_insert() to collect all parameters in the statement.
+ # see if this gathering can be simplified somehow
+ if accumulate_bind_names is not None:
+ accumulate_bind_names.add(name)
+ if visited_bindparam is not None:
+ visited_bindparam.append(name)
if not escaped_from:
assert insert_crud_params is not None
escaped_bind_names: Mapping[str, str]
+ expand_pos_lower_index = expand_pos_upper_index = 0
+
if not self.positional:
if self.escaped_bind_names:
escaped_bind_names = self.escaped_bind_names
keys_to_replace = set()
base_parameters = {}
executemany_values_w_comma = f"({imv.single_values_expr}), "
+
+ all_names_we_will_expand: Set[str] = set()
+ for elem in imv.insert_crud_params:
+ all_names_we_will_expand.update(elem[3])
+
+ # get the start and end position in a particular list
+ # of parameters where we will be doing the "expanding".
+ # statements can have params on either side or both sides,
+ # given RETURNING and CTEs
+ if all_names_we_will_expand:
+ positiontup = self.positiontup
+ assert positiontup is not None
+
+ all_expand_positions = {
+ idx
+ for idx, name in enumerate(positiontup)
+ if name in all_names_we_will_expand
+ }
+ expand_pos_lower_index = min(all_expand_positions)
+ expand_pos_upper_index = max(all_expand_positions) + 1
+ assert (
+ len(all_expand_positions)
+ == expand_pos_upper_index - expand_pos_lower_index
+ )
+
if self._numeric_binds:
escaped = re.escape(self._numeric_binds_identifier_char)
executemany_values_w_comma = re.sub(
replaced_parameters: Any
if self.positional:
- # the assumption here is that any parameters that are not
- # in the VALUES clause are expected to be parameterized
- # expressions in the RETURNING (or maybe ON CONFLICT) clause.
- # So based on
- # which sequence comes first in the compiler's INSERT
- # statement tells us where to expand the parameters.
-
- # otherwise we probably shouldn't be doing insertmanyvalues
- # on the statement.
-
num_ins_params = imv.num_positional_params_counted
batch_iterator: Iterable[Tuple[Any, ...]]
if num_ins_params == len(batch[0]):
- extra_params = ()
+ extra_params_left = extra_params_right = ()
batch_iterator = batch
- elif self.returning_precedes_values or self._numeric_binds:
- extra_params = batch[0][:-num_ins_params]
- batch_iterator = (b[-num_ins_params:] for b in batch)
else:
- extra_params = batch[0][num_ins_params:]
- batch_iterator = (b[:num_ins_params] for b in batch)
+ extra_params_left = batch[0][:expand_pos_lower_index]
+ extra_params_right = batch[0][expand_pos_upper_index:]
+ batch_iterator = (
+ b[expand_pos_lower_index:expand_pos_upper_index]
+ for b in batch
+ )
+
+ expanded_values_string = (
+ executemany_values_w_comma * len(batch)
+ )[:-2]
- values_string = (executemany_values_w_comma * len(batch))[:-2]
if self._numeric_binds and num_ins_params > 0:
+ # numeric will always number the parameters inside of
+ # VALUES (and thus order self.positiontup) to be higher
+ # than non-VALUES parameters, no matter where in the
+ # statement those non-VALUES parameters appear (this is
+ # ensured in _process_numeric by numbering first all
+ # params that are not in _values_bindparam)
+ # therefore all extra params are always
+ # on the left side and numbered lower than the VALUES
+ # parameters
+ assert not extra_params_right
+
+ start = expand_pos_lower_index + 1
+ end = num_ins_params * (len(batch)) + start
+
# need to format here, since statement may contain
# unescaped %, while values_string contains just (%s, %s)
- start = len(extra_params) + 1
- end = num_ins_params * len(batch) + start
positions = tuple(
f"{self._numeric_binds_identifier_char}{i}"
for i in range(start, end)
)
- values_string = values_string % positions
+ expanded_values_string = expanded_values_string % positions
replaced_statement = statement.replace(
- "__EXECMANY_TOKEN__", values_string
+ "__EXECMANY_TOKEN__", expanded_values_string
)
replaced_parameters = tuple(
itertools.chain.from_iterable(batch_iterator)
)
- if self.returning_precedes_values or self._numeric_binds:
- replaced_parameters = extra_params + replaced_parameters
- else:
- replaced_parameters = replaced_parameters + extra_params
+
+ replaced_parameters = (
+ extra_params_left
+ + replaced_parameters
+ + extra_params_right
+ )
+
else:
replaced_values_clauses = []
replaced_parameters = base_parameters.copy()
)
batchnum += 1
- def visit_insert(self, insert_stmt, **kw):
+ def visit_insert(self, insert_stmt, visited_bindparam=None, **kw):
compile_state = insert_stmt._compile_state_factory(
insert_stmt, self, **kw
counted_bindparam = 0
+ # reset any incoming "visited_bindparam" collection
+ visited_bindparam = None
+
# for positional, insertmanyvalues needs to know how many
# bound parameters are in the VALUES sequence; there's no simple
# rule because default expressions etc. can have zero or more
# this very simplistic "count after" works and is
# likely the least amount of callcounts, though looks clumsy
if self.positional:
- self._visited_bindparam = []
+ # if we are inside a CTE, don't count parameters
+ # here since they wont be for insertmanyvalues. keep
+ # visited_bindparam at None so no counting happens.
+ # see #9173
+ has_visiting_cte = "visiting_cte" in kw
+ if not has_visiting_cte:
+ visited_bindparam = []
crud_params_struct = crud._get_crud_params(
- self, insert_stmt, compile_state, toplevel, **kw
+ self,
+ insert_stmt,
+ compile_state,
+ toplevel,
+ visited_bindparam=visited_bindparam,
+ **kw,
)
- if self.positional:
- assert self._visited_bindparam is not None
- counted_bindparam = len(self._visited_bindparam)
+ if self.positional and visited_bindparam is not None:
+ counted_bindparam = len(visited_bindparam)
if self._numeric_binds:
if self._values_bindparam is not None:
- self._values_bindparam += self._visited_bindparam
+ self._values_bindparam += visited_bindparam
else:
- self._values_bindparam = self._visited_bindparam
- self._visited_bindparam = None
+ self._values_bindparam = visited_bindparam
crud_params_single = crud_params_struct.single_params
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import mock
+from sqlalchemy.testing import provision
from sqlalchemy.testing.provision import normalize_sequence
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
eq_(result.inserted_primary_key_rows, [(1,), (2,), (3,)])
+ @testing.requires.ctes_on_dml
+ @testing.variation("add_expr_returning", [True, False])
+ def test_insert_w_bindparam_in_nested_insert(
+ self, connection, add_expr_returning
+ ):
+ """test related to #9173"""
+
+ data, extra_table = self.tables("data", "extra_table")
+
+ inst = (
+ extra_table.insert()
+ .values(x_value="x", y_value="y")
+ .returning(extra_table.c.id)
+ .cte("inst")
+ )
+
+ stmt = (
+ data.insert()
+ .values(x="the x", z=select(inst.c.id).scalar_subquery())
+ .add_cte(inst)
+ )
+
+ if add_expr_returning:
+ stmt = stmt.returning(data.c.id, data.c.y + " returned y")
+ else:
+ stmt = stmt.returning(data.c.id)
+
+ result = connection.execute(
+ stmt,
+ [
+ {"y": "y1"},
+ {"y": "y2"},
+ {"y": "y3"},
+ ],
+ )
+
+ result_rows = result.all()
+
+ ids = [row[0] for row in result_rows]
+
+ extra_row = connection.execute(
+ select(extra_table).order_by(extra_table.c.id)
+ ).one()
+ extra_row_id = extra_row[0]
+ eq_(extra_row, (extra_row_id, "x", "y"))
+ eq_(
+ connection.execute(select(data).order_by(data.c.id)).all(),
+ [
+ (ids[0], "the x", "y1", extra_row_id),
+ (ids[1], "the x", "y2", extra_row_id),
+ (ids[2], "the x", "y3", extra_row_id),
+ ],
+ )
+
+ @testing.requires.provisioned_upsert
+ def test_upsert_w_returning(self, connection):
+ """test cases that will execise SQL similar to that of
+ test/orm/dml/test_bulk_statements.py
+
+ """
+
+ data = self.tables.data
+
+ initial_data = [
+ {"x": "x1", "y": "y1", "z": 4},
+ {"x": "x2", "y": "y2", "z": 8},
+ ]
+ ids = connection.scalars(
+ data.insert().returning(data.c.id), initial_data
+ ).all()
+
+ upsert_data = [
+ {
+ "id": ids[0],
+ "x": "x1",
+ "y": "y1",
+ },
+ {
+ "id": 32,
+ "x": "x19",
+ "y": "y7",
+ },
+ {
+ "id": ids[1],
+ "x": "x5",
+ "y": "y6",
+ },
+ {
+ "id": 28,
+ "x": "x9",
+ "y": "y15",
+ },
+ ]
+
+ stmt = provision.upsert(
+ config,
+ data,
+ (data,),
+ lambda inserted: {"x": inserted.x + " upserted"},
+ )
+
+ result = connection.execute(stmt, upsert_data)
+
+ eq_(
+ result.all(),
+ [
+ (ids[0], "x1 upserted", "y1", 4),
+ (32, "x19", "y7", 5),
+ (ids[1], "x5 upserted", "y2", 8),
+ (28, "x9", "y15", 5),
+ ],
+ )
+
@testing.combinations(True, False, argnames="use_returning")
@testing.combinations(1, 2, argnames="num_embedded_params")
@testing.combinations(True, False, argnames="use_whereclause")
def test_insert_w_bindparam_in_subq(
self, connection, use_returning, num_embedded_params, use_whereclause
):
- """test #8639"""
+ """test #8639
+
+ see also test_insert_w_bindparam_in_nested_insert
+
+ """
t = self.tables.data
extra = self.tables.extra_table