from .elements import BindParameter
from .elements import ColumnClause
from .elements import ColumnElement
- from .elements import KeyedColumnElement
from .elements import Label
from .functions import Function
from .selectable import AliasedReturnsRows
"named": ":%(name)s",
}
+
_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]")
_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__"))
is_default_expr: bool
single_values_expr: str
- insert_crud_params: List[Tuple[KeyedColumnElement[Any], str, str]]
+ insert_crud_params: List[crud._CrudParamElementStr]
num_positional_params_counted: int
skip_bind_expression=False,
literal_execute=False,
render_postcompile=False,
+ accumulate_bind_names=None,
**kwargs,
):
if not skip_bind_expression:
literal_binds=literal_binds,
literal_execute=literal_execute,
render_postcompile=render_postcompile,
+ accumulate_bind_names=accumulate_bind_names,
**kwargs,
)
if bindparam.expanding:
self.binds[bindparam.key] = self.binds[name] = bindparam
+ if accumulate_bind_names is not None:
+ accumulate_bind_names.add(name)
+
# if we are given a cache key that we're going to match against,
# relate the bindparam here to one that is most likely present
# in the "extracted params" portion of the cache key. this is used
all_keys = set(parameters[0])
- escaped_insert_crud_params: Sequence[Any] = [
- (escaped_bind_names.get(col.key, col.key), formatted)
- for col, _, formatted in insert_crud_params
- ]
+ def apply_placeholders(keys, formatted):
+ for key in keys:
+ key = escaped_bind_names.get(key, key)
+ formatted = formatted.replace(
+ self.bindtemplate % {"name": key},
+ self.bindtemplate
+ % {"name": f"{key}__EXECMANY_INDEX__"},
+ )
+ return formatted
+
+ formatted_values_clause = f"""({', '.join(
+ apply_placeholders(bind_keys, formatted)
+ for _, _, formatted, bind_keys in insert_crud_params
+ )})"""
keys_to_replace = all_keys.intersection(
- key for key, _ in escaped_insert_crud_params
+ escaped_bind_names.get(key, key)
+ for _, _, _, bind_keys in insert_crud_params
+ for key in bind_keys
)
base_parameters = {
key: parameters[0][key]
}
executemany_values_w_comma = ""
else:
- escaped_insert_crud_params = ()
+ formatted_values_clause = ""
keys_to_replace = set()
base_parameters = {}
executemany_values_w_comma = f"({imv.single_values_expr}), "
replaced_parameters = base_parameters.copy()
for i, param in enumerate(batch):
- new_tokens = [
- formatted.replace(key, f"{key}__{i}")
- if key in param
- else formatted
- for key, formatted in escaped_insert_crud_params
- ]
replaced_values_clauses.append(
- f"({', '.join(new_tokens)})"
+ formatted_values_clause.replace(
+ "EXECMANY_INDEX__", str(i)
+ )
)
replaced_parameters.update(
if crud_params_single or not supports_default_values:
text += " (%s)" % ", ".join(
- [expr for _, expr, _ in crud_params_single]
+ [expr for _, expr, _, _ in crud_params_single]
)
if self.implicit_returning or insert_stmt._returning:
True,
self.dialect.default_metavalue_token,
cast(
- "List[Tuple[KeyedColumnElement[Any], str, str]]",
- crud_params_single,
+ "List[crud._CrudParamElementStr]", crud_params_single
),
(positiontup_after - positiontup_before),
)
text += " VALUES %s" % (
", ".join(
"(%s)"
- % (", ".join(value for _, _, value in crud_param_set))
+ % (", ".join(value for _, _, value, _ in crud_param_set))
for crud_param_set in crud_params_struct.all_multi_params
)
)
insert_single_values_expr = ", ".join(
[
value
- for _, _, value in cast(
- "List[Tuple[Any, Any, str]]", crud_params_single
+ for _, _, value, _ in cast(
+ "List[crud._CrudParamElementStr]",
+ crud_params_single,
)
]
)
False,
insert_single_values_expr,
cast(
- "List[Tuple[KeyedColumnElement[Any], str, str]]",
+ "List[crud._CrudParamElementStr]",
crud_params_single,
),
positiontup_after - positiontup_before,
text += " SET "
text += ", ".join(
expr + "=" + value
- for _, expr, value in cast(
- "List[Tuple[Any, str, str]]", crud_params
+ for _, expr, value, _ in cast(
+ "List[Tuple[Any, str, str, Any]]", crud_params
)
)
from typing import Callable
from typing import cast
from typing import Dict
+from typing import Iterable
from typing import List
from typing import MutableMapping
from typing import NamedTuple
from typing import Optional
from typing import overload
from typing import Sequence
+from typing import Set
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from .dml import DMLState
from .dml import ValuesBase
from .elements import ColumnElement
+ from .elements import KeyedColumnElement
from .schema import _SQLExprDefault
REQUIRED = util.symbol(
return c
-_CrudParamSequence = Sequence[
- Tuple[
- "ColumnElement[Any]",
- str,
- Optional[Union[str, "_SQLExprDefault"]],
- ]
+_CrudParamElement = Tuple[
+ "ColumnElement[Any]",
+ str,
+ Optional[Union[str, "_SQLExprDefault"]],
+ Iterable[str],
+]
+_CrudParamElementStr = Tuple[
+ "KeyedColumnElement[Any]",
+ str,
+ str,
+ Iterable[str],
]
+_CrudParamElementSQLExpr = Tuple[
+ "ColumnClause[Any]",
+ str,
+ "_SQLExprDefault",
+ Iterable[str],
+]
+
+_CrudParamSequence = List[_CrudParamElement]
class _CrudParams(NamedTuple):
single_params: _CrudParamSequence
- all_multi_params: List[Sequence[Tuple[ColumnClause[Any], str, str]]]
+
+ all_multi_params: List[Sequence[_CrudParamElementStr]]
def _get_crud_params(
c,
compiler.preparer.format_column(c),
_create_bind_param(compiler, c, None, required=True),
+ (c.key,),
)
for c in stmt.table.columns
],
)
# create a list of column assignment clauses as tuples
- values: List[
- Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]]
- ] = []
+ values: List[_CrudParamElement] = []
if stmt_parameter_tuples is not None:
_get_stmt_parameter_tuples_params(
compiler,
stmt,
compile_state,
- cast("Sequence[Tuple[ColumnClause[Any], str, str]]", values),
+ cast(
+ "Sequence[_CrudParamElementStr]",
+ values,
+ ),
cast("Callable[..., str]", _column_as_key),
kw,
)
_as_dml_column(stmt.table.columns[0]),
compiler.preparer.format_column(stmt.table.columns[0]),
compiler.dialect.default_metavalue_token,
+ (),
)
]
compiler.stack[-1]["insert_from_select"] = stmt.select
- add_select_cols: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]] = []
+ add_select_cols: List[_CrudParamElementSQLExpr] = []
if stmt.include_insert_from_select_defaults:
col_set = set(cols)
for col in stmt.table.columns:
col_key = _getattr_col_key(c)
if col_key in parameters and col_key not in check_columns:
parameters.pop(col_key)
- values.append((c, compiler.preparer.format_column(c), None))
+ values.append((c, compiler.preparer.format_column(c), None, ()))
else:
_append_param_insert_select_hasdefault(
compiler, stmt, c, add_select_cols, kw
f"Can't extend statement for INSERT..FROM SELECT to include "
f"additional default-holding column(s) "
f"""{
- ', '.join(repr(key) for _, key, _ in add_select_cols)
+ ', '.join(repr(key) for _, key, _, _ in add_select_cols)
}. Convert the selectable to a subquery() first, or pass """
"include_defaults=False to Insert.from_select() to skip these "
"columns."
ins_from_select = ins_from_select._generate()
# copy raw_columns
ins_from_select._raw_columns = list(ins_from_select._raw_columns) + [
- expr for _, _, expr in add_select_cols
+ expr for _, _, expr, _ in add_select_cols
]
compiler.stack[-1]["insert_from_select"] = ins_from_select
c, use_table=compile_state.include_table_with_column_exprs
)
+ accumulated_bind_names: Set[str] = set()
+
if coercions._is_literal(value):
if (
if not _compile_state_isinsert(compile_state)
or not compile_state._has_multi_parameters
else "%s_m0" % _col_bind_name(c),
+ accumulate_bind_names=accumulated_bind_names,
**kw,
)
elif value._is_bind_parameter:
if not _compile_state_isinsert(compile_state)
or not compile_state._has_multi_parameters
else "%s_m0" % _col_bind_name(c),
+ accumulate_bind_names=accumulated_bind_names,
**kw,
)
else:
# value is a SQL expression
- value = compiler.process(value.self_group(), **kw)
+ value = compiler.process(
+ value.self_group(),
+ accumulate_bind_names=accumulated_bind_names,
+ **kw,
+ )
if compile_state.isupdate:
if implicit_return_defaults and c in implicit_return_defaults:
compiler.postfetch.append(c)
- values.append((c, col_value, value))
+ values.append((c, col_value, value, accumulated_bind_names))
def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
not c.default.optional
or not compiler.dialect.sequences_optional
):
+ accumulated_bind_names: Set[str] = set()
values.append(
(
c,
compiler.preparer.format_column(c),
- compiler.process(c.default, **kw),
+ compiler.process(
+ c.default,
+ accumulate_bind_names=accumulated_bind_names,
+ **kw,
+ ),
+ accumulated_bind_names,
)
)
compiler.implicit_returning.append(c)
elif c.default.is_clause_element:
+ accumulated_bind_names = set()
values.append(
(
c,
compiler.preparer.format_column(c),
- compiler.process(c.default.arg.self_group(), **kw),
+ compiler.process(
+ c.default.arg.self_group(),
+ accumulate_bind_names=accumulated_bind_names,
+ **kw,
+ ),
+ accumulated_bind_names,
)
)
compiler.implicit_returning.append(c)
c,
compiler.preparer.format_column(c),
_create_insert_prefetch_bind_param(compiler, c, **kw),
+ (c.key,),
)
)
elif c is stmt.table._autoincrement_column or c.server_default is not None:
c,
compiler.preparer.format_column(c),
_create_insert_prefetch_bind_param(compiler, c, **kw),
+ (c.key,),
)
)
elif (
if compiler.dialect.supports_sequences and (
not c.default.optional or not compiler.dialect.sequences_optional
):
+ accumulated_bind_names: Set[str] = set()
values.append(
(
c,
compiler.preparer.format_column(c),
- compiler.process(c.default, **kw),
+ compiler.process(
+ c.default,
+ accumulate_bind_names=accumulated_bind_names,
+ **kw,
+ ),
+ accumulated_bind_names,
)
)
if implicit_return_defaults and c in implicit_return_defaults:
elif not c.primary_key:
compiler.postfetch.append(c)
elif c.default.is_clause_element:
+ accumulated_bind_names = set()
values.append(
(
c,
compiler.preparer.format_column(c),
- compiler.process(c.default.arg.self_group(), **kw),
+ compiler.process(
+ c.default.arg.self_group(),
+ accumulate_bind_names=accumulated_bind_names,
+ **kw,
+ ),
+ accumulated_bind_names,
)
)
c,
compiler.preparer.format_column(c),
_create_insert_prefetch_bind_param(compiler, c, **kw),
+ (c.key,),
)
)
compiler: SQLCompiler,
stmt: ValuesBase,
c: ColumnClause[Any],
- values: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]],
+ values: List[_CrudParamElementSQLExpr],
kw: Dict[str, Any],
) -> None:
c,
compiler.preparer.format_column(c),
c.default.next_value(),
+ (),
)
)
elif default_is_clause_element(c.default):
c,
compiler.preparer.format_column(c),
c.default.arg.self_group(),
+ (),
)
)
else:
_create_insert_prefetch_bind_param(
compiler, c, process=False, **kw
),
+ (c.key,),
)
)
use_table=include_table,
),
compiler.process(c.onupdate.arg.self_group(), **kw),
+ (),
)
)
if implicit_return_defaults and c in implicit_return_defaults:
use_table=include_table,
),
_create_update_prefetch_bind_param(compiler, c, **kw),
+ (c.key,),
)
)
elif c.server_onupdate is not None:
def _process_multiparam_default_bind(
compiler: SQLCompiler,
stmt: ValuesBase,
- c: ColumnClause[Any],
+ c: KeyedColumnElement[Any],
index: int,
kw: Dict[str, Any],
) -> str:
name=_col_bind_name(c),
**kw, # TODO: no test coverage for literal binds here
)
+ accumulated_bind_names: Iterable[str] = (c.key,)
elif value._is_bind_parameter:
+ cbn = _col_bind_name(c)
value = _handle_values_anonymous_param(
- compiler, c, value, name=_col_bind_name(c), **kw
+ compiler, c, value, name=cbn, **kw
)
+ accumulated_bind_names = (cbn,)
else:
compiler.postfetch.append(c)
value = compiler.process(value.self_group(), **kw)
- values.append((c, col_value, value))
+ accumulated_bind_names = ()
+ values.append((c, col_value, value, accumulated_bind_names))
# determine tables which are actually to be updated - process onupdate
# and server_onupdate for these
for t in affected_tables:
compiler.process(
c.onupdate.arg.self_group(), **kw
),
+ (),
)
)
compiler.postfetch.append(c)
_create_update_prefetch_bind_param(
compiler, c, name=_col_bind_name(c), **kw
),
+ (c.key,),
)
)
elif c.server_onupdate is not None:
compiler: SQLCompiler,
stmt: ValuesBase,
compile_state: DMLState,
- initial_values: Sequence[Tuple[ColumnClause[Any], str, str]],
+ initial_values: Sequence[_CrudParamElementStr],
_column_as_key: Callable[..., str],
kw: Dict[str, Any],
-) -> List[Sequence[Tuple[ColumnClause[Any], str, str]]]:
+) -> List[Sequence[_CrudParamElementStr]]:
values_0 = initial_values
values = [initial_values]
mp = compile_state._multi_parameters
assert mp is not None
for i, row in enumerate(mp[1:]):
- extension: List[Tuple[ColumnClause[Any], str, str]] = []
+ extension: List[_CrudParamElementStr] = []
row = {_column_as_key(key): v for key, v in row.items()}
- for (col, col_expr, param) in values_0:
+ for (col, col_expr, param, accumulated_names) in values_0:
if col.key in row:
key = col.key
compiler, stmt, col, i, kw
)
- extension.append((col, col_expr, new_param))
+ extension.append((col, col_expr, new_param, accumulated_names))
values.append(extension)
v = compiler.process(v.self_group(), **kw)
- values.append((k, col_expr, v))
+ # TODO: not sure if accumulated_bind_names applies here
+ values.append((k, col_expr, v, ()))
def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
import itertools
from sqlalchemy import and_
+from sqlalchemy import bindparam
from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import INT
from sqlalchemy import Integer
from sqlalchemy import literal
+from sqlalchemy import select
from sqlalchemy import Sequence
from sqlalchemy import sql
from sqlalchemy import String
Column("\u6e2c\u8a66", Integer),
)
+ Table(
+ "extra_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x_value", String(50)),
+ Column("y_value", String(50)),
+ )
+
def test_insert_unicode_keys(self, connection):
table = self.tables["Unitéble2"]
eq_(result.inserted_primary_key_rows, [(1,), (2,), (3,)])
+ @testing.combinations(True, False, argnames="use_returning")
+ @testing.combinations(1, 2, argnames="num_embedded_params")
+ @testing.combinations(True, False, argnames="use_whereclause")
+ @testing.crashes(
+ "+mariadbconnector",
+ "returning crashes, regular executemany malfunctions",
+ )
+ def test_insert_w_bindparam_in_subq(
+ self, connection, use_returning, num_embedded_params, use_whereclause
+ ):
+ """test #8639"""
+
+ t = self.tables.data
+ extra = self.tables.extra_table
+
+ conn = connection
+ connection.execute(
+ extra.insert(),
+ [
+ {"x_value": "p1", "y_value": "yv1"},
+ {"x_value": "p2", "y_value": "yv2"},
+ {"x_value": "p1_p1", "y_value": "yv3"},
+ {"x_value": "p2_p2", "y_value": "yv4"},
+ ],
+ )
+
+ if num_embedded_params == 1:
+ if use_whereclause:
+ scalar_subq = select(bindparam("paramname")).scalar_subquery()
+ params = [
+ {"paramname": "p1_p1", "y": "y1"},
+ {"paramname": "p2_p2", "y": "y2"},
+ ]
+ else:
+ scalar_subq = (
+ select(extra.c.x_value)
+ .where(extra.c.y_value == bindparam("y_value"))
+ .scalar_subquery()
+ )
+ params = [
+ {"y_value": "yv3", "y": "y1"},
+ {"y_value": "yv4", "y": "y2"},
+ ]
+
+ elif num_embedded_params == 2:
+ if use_whereclause:
+ scalar_subq = (
+ select(
+ bindparam("paramname1", type_=String) + extra.c.x_value
+ )
+ .where(extra.c.y_value == bindparam("y_value"))
+ .scalar_subquery()
+ )
+ params = [
+ {"paramname1": "p1_", "y_value": "yv1", "y": "y1"},
+ {"paramname1": "p2_", "y_value": "yv2", "y": "y2"},
+ ]
+ else:
+ scalar_subq = select(
+ bindparam("paramname1", type_=String)
+ + bindparam("paramname2", type_=String)
+ ).scalar_subquery()
+ params = [
+ {"paramname1": "p1_", "paramname2": "p1", "y": "y1"},
+ {"paramname1": "p2_", "paramname2": "p2", "y": "y2"},
+ ]
+ else:
+ assert False
+
+ stmt = t.insert().values(x=scalar_subq)
+ if use_returning:
+ stmt = stmt.returning(t.c["x", "y"])
+
+ result = conn.execute(stmt, params)
+
+ if use_returning:
+ eq_(result.all(), [("p1_p1", "y1"), ("p2_p2", "y2")])
+
+ result = conn.execute(select(t.c["x", "y"]))
+
+ eq_(result.all(), [("p1_p1", "y1"), ("p2_p2", "y2")])
+
def test_insert_returning_defaults(self, connection):
t = self.tables.data