From: Mike Bayer Date: Sat, 15 Oct 2022 19:20:21 +0000 (-0400) Subject: accommodate arbitrary embedded params in insertmanyvalues X-Git-Tag: rel_2_0_0b2~20 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2b966de4196c8271934769337780f7d504d431cf;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git accommodate arbitrary embedded params in insertmanyvalues Fixed bug in new "insertmanyvalues" feature where INSERT that included a subquery with :func:`_sql.bindparam` inside of it would fail to render correctly in "insertmanyvalues" format. This affected psycopg2 most directly as "insertmanyvalues" is used unconditionally with this driver. Fixes: #8639 Change-Id: I67903fa86afe208899d4f23f940e0727d1be2ce3 --- diff --git a/doc/build/changelog/unreleased_20/8639.rst b/doc/build/changelog/unreleased_20/8639.rst new file mode 100644 index 0000000000..46cce757d9 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8639.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, regression, sql + :tickets: 8639 + + Fixed bug in new "insertmanyvalues" feature where INSERT that included a + subquery with :func:`_sql.bindparam` inside of it would fail to render + correctly in "insertmanyvalues" format. This affected psycopg2 most + directly as "insertmanyvalues" is used unconditionally with this driver. + diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index dd40bfe345..efe0ea2b43 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -94,7 +94,6 @@ if typing.TYPE_CHECKING: from .elements import BindParameter from .elements import ColumnClause from .elements import ColumnElement - from .elements import KeyedColumnElement from .elements import Label from .functions import Function from .selectable import AliasedReturnsRows @@ -236,6 +235,7 @@ BIND_TEMPLATES = { "named": ":%(name)s", } + _BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]") _BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__")) @@ -416,7 +416,7 @@ class _InsertManyValues(NamedTuple): is_default_expr: bool single_values_expr: str - insert_crud_params: List[Tuple[KeyedColumnElement[Any], str, str]] + insert_crud_params: List[crud._CrudParamElementStr] num_positional_params_counted: int @@ -2960,6 +2960,7 @@ class SQLCompiler(Compiled): skip_bind_expression=False, literal_execute=False, render_postcompile=False, + accumulate_bind_names=None, **kwargs, ): if not skip_bind_expression: @@ -2973,6 +2974,7 @@ class SQLCompiler(Compiled): literal_binds=literal_binds, literal_execute=literal_execute, render_postcompile=render_postcompile, + accumulate_bind_names=accumulate_bind_names, **kwargs, ) if bindparam.expanding: @@ -3063,6 +3065,9 @@ class SQLCompiler(Compiled): self.binds[bindparam.key] = self.binds[name] = bindparam + if accumulate_bind_names is not None: + accumulate_bind_names.add(name) + # if we are given a cache key that we're going to match against, # relate the bindparam here to one that is most likely present # in the "extracted params" portion of the cache key. this is used @@ -4646,13 +4651,25 @@ class SQLCompiler(Compiled): all_keys = set(parameters[0]) - escaped_insert_crud_params: Sequence[Any] = [ - (escaped_bind_names.get(col.key, col.key), formatted) - for col, _, formatted in insert_crud_params - ] + def apply_placeholders(keys, formatted): + for key in keys: + key = escaped_bind_names.get(key, key) + formatted = formatted.replace( + self.bindtemplate % {"name": key}, + self.bindtemplate + % {"name": f"{key}__EXECMANY_INDEX__"}, + ) + return formatted + + formatted_values_clause = f"""({', '.join( + apply_placeholders(bind_keys, formatted) + for _, _, formatted, bind_keys in insert_crud_params + )})""" keys_to_replace = all_keys.intersection( - key for key, _ in escaped_insert_crud_params + escaped_bind_names.get(key, key) + for _, _, _, bind_keys in insert_crud_params + for key in bind_keys ) base_parameters = { key: parameters[0][key] @@ -4660,7 +4677,7 @@ class SQLCompiler(Compiled): } executemany_values_w_comma = "" else: - escaped_insert_crud_params = () + formatted_values_clause = "" keys_to_replace = set() base_parameters = {} executemany_values_w_comma = f"({imv.single_values_expr}), " @@ -4723,14 +4740,10 @@ class SQLCompiler(Compiled): replaced_parameters = base_parameters.copy() for i, param in enumerate(batch): - new_tokens = [ - formatted.replace(key, f"{key}__{i}") - if key in param - else formatted - for key, formatted in escaped_insert_crud_params - ] replaced_values_clauses.append( - f"({', '.join(new_tokens)})" + formatted_values_clause.replace( + "EXECMANY_INDEX__", str(i) + ) ) replaced_parameters.update( @@ -4841,7 +4854,7 @@ class SQLCompiler(Compiled): if crud_params_single or not supports_default_values: text += " (%s)" % ", ".join( - [expr for _, expr, _ in crud_params_single] + [expr for _, expr, _, _ in crud_params_single] ) if self.implicit_returning or insert_stmt._returning: @@ -4902,8 +4915,7 @@ class SQLCompiler(Compiled): True, self.dialect.default_metavalue_token, cast( - "List[Tuple[KeyedColumnElement[Any], str, str]]", - crud_params_single, + "List[crud._CrudParamElementStr]", crud_params_single ), (positiontup_after - positiontup_before), ) @@ -4911,7 +4923,7 @@ class SQLCompiler(Compiled): text += " VALUES %s" % ( ", ".join( "(%s)" - % (", ".join(value for _, _, value in crud_param_set)) + % (", ".join(value for _, _, value, _ in crud_param_set)) for crud_param_set in crud_params_struct.all_multi_params ) ) @@ -4921,8 +4933,9 @@ class SQLCompiler(Compiled): insert_single_values_expr = ", ".join( [ value - for _, _, value in cast( - "List[Tuple[Any, Any, str]]", crud_params_single + for _, _, value, _ in cast( + "List[crud._CrudParamElementStr]", + crud_params_single, ) ] ) @@ -4935,7 +4948,7 @@ class SQLCompiler(Compiled): False, insert_single_values_expr, cast( - "List[Tuple[KeyedColumnElement[Any], str, str]]", + "List[crud._CrudParamElementStr]", crud_params_single, ), positiontup_after - positiontup_before, @@ -5058,8 +5071,8 @@ class SQLCompiler(Compiled): text += " SET " text += ", ".join( expr + "=" + value - for _, expr, value in cast( - "List[Tuple[Any, str, str]]", crud_params + for _, expr, value, _ in cast( + "List[Tuple[Any, str, str, Any]]", crud_params ) ) diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 22fffb73a1..31d127c2c3 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -18,12 +18,14 @@ from typing import Any from typing import Callable from typing import cast from typing import Dict +from typing import Iterable from typing import List from typing import MutableMapping from typing import NamedTuple from typing import Optional from typing import overload from typing import Sequence +from typing import Set from typing import Tuple from typing import TYPE_CHECKING from typing import Union @@ -49,6 +51,7 @@ if TYPE_CHECKING: from .dml import DMLState from .dml import ValuesBase from .elements import ColumnElement + from .elements import KeyedColumnElement from .schema import _SQLExprDefault REQUIRED = util.symbol( @@ -74,18 +77,32 @@ def _as_dml_column(c: ColumnElement[Any]) -> ColumnClause[Any]: return c -_CrudParamSequence = Sequence[ - Tuple[ - "ColumnElement[Any]", - str, - Optional[Union[str, "_SQLExprDefault"]], - ] +_CrudParamElement = Tuple[ + "ColumnElement[Any]", + str, + Optional[Union[str, "_SQLExprDefault"]], + Iterable[str], +] +_CrudParamElementStr = Tuple[ + "KeyedColumnElement[Any]", + str, + str, + Iterable[str], ] +_CrudParamElementSQLExpr = Tuple[ + "ColumnClause[Any]", + str, + "_SQLExprDefault", + Iterable[str], +] + +_CrudParamSequence = List[_CrudParamElement] class _CrudParams(NamedTuple): single_params: _CrudParamSequence - all_multi_params: List[Sequence[Tuple[ColumnClause[Any], str, str]]] + + all_multi_params: List[Sequence[_CrudParamElementStr]] def _get_crud_params( @@ -175,6 +192,7 @@ def _get_crud_params( c, compiler.preparer.format_column(c), _create_bind_param(compiler, c, None, required=True), + (c.key,), ) for c in stmt.table.columns ], @@ -220,9 +238,7 @@ def _get_crud_params( ) # create a list of column assignment clauses as tuples - values: List[ - Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]] - ] = [] + values: List[_CrudParamElement] = [] if stmt_parameter_tuples is not None: _get_stmt_parameter_tuples_params( @@ -307,7 +323,10 @@ def _get_crud_params( compiler, stmt, compile_state, - cast("Sequence[Tuple[ColumnClause[Any], str, str]]", values), + cast( + "Sequence[_CrudParamElementStr]", + values, + ), cast("Callable[..., str]", _column_as_key), kw, ) @@ -326,6 +345,7 @@ def _get_crud_params( _as_dml_column(stmt.table.columns[0]), compiler.preparer.format_column(stmt.table.columns[0]), compiler.dialect.default_metavalue_token, + (), ) ] @@ -488,7 +508,7 @@ def _scan_insert_from_select_cols( compiler.stack[-1]["insert_from_select"] = stmt.select - add_select_cols: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]] = [] + add_select_cols: List[_CrudParamElementSQLExpr] = [] if stmt.include_insert_from_select_defaults: col_set = set(cols) for col in stmt.table.columns: @@ -499,7 +519,7 @@ def _scan_insert_from_select_cols( col_key = _getattr_col_key(c) if col_key in parameters and col_key not in check_columns: parameters.pop(col_key) - values.append((c, compiler.preparer.format_column(c), None)) + values.append((c, compiler.preparer.format_column(c), None, ())) else: _append_param_insert_select_hasdefault( compiler, stmt, c, add_select_cols, kw @@ -513,7 +533,7 @@ def _scan_insert_from_select_cols( f"Can't extend statement for INSERT..FROM SELECT to include " f"additional default-holding column(s) " f"""{ - ', '.join(repr(key) for _, key, _ in add_select_cols) + ', '.join(repr(key) for _, key, _, _ in add_select_cols) }. Convert the selectable to a subquery() first, or pass """ "include_defaults=False to Insert.from_select() to skip these " "columns." @@ -521,7 +541,7 @@ def _scan_insert_from_select_cols( ins_from_select = ins_from_select._generate() # copy raw_columns ins_from_select._raw_columns = list(ins_from_select._raw_columns) + [ - expr for _, _, expr in add_select_cols + expr for _, _, expr, _ in add_select_cols ] compiler.stack[-1]["insert_from_select"] = ins_from_select @@ -748,6 +768,8 @@ def _append_param_parameter( c, use_table=compile_state.include_table_with_column_exprs ) + accumulated_bind_names: Set[str] = set() + if coercions._is_literal(value): if ( @@ -772,6 +794,7 @@ def _append_param_parameter( if not _compile_state_isinsert(compile_state) or not compile_state._has_multi_parameters else "%s_m0" % _col_bind_name(c), + accumulate_bind_names=accumulated_bind_names, **kw, ) elif value._is_bind_parameter: @@ -796,11 +819,16 @@ def _append_param_parameter( if not _compile_state_isinsert(compile_state) or not compile_state._has_multi_parameters else "%s_m0" % _col_bind_name(c), + accumulate_bind_names=accumulated_bind_names, **kw, ) else: # value is a SQL expression - value = compiler.process(value.self_group(), **kw) + value = compiler.process( + value.self_group(), + accumulate_bind_names=accumulated_bind_names, + **kw, + ) if compile_state.isupdate: if implicit_return_defaults and c in implicit_return_defaults: @@ -828,7 +856,7 @@ def _append_param_parameter( compiler.postfetch.append(c) - values.append((c, col_value, value)) + values.append((c, col_value, value, accumulated_bind_names)) def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): @@ -843,20 +871,32 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): not c.default.optional or not compiler.dialect.sequences_optional ): + accumulated_bind_names: Set[str] = set() values.append( ( c, compiler.preparer.format_column(c), - compiler.process(c.default, **kw), + compiler.process( + c.default, + accumulate_bind_names=accumulated_bind_names, + **kw, + ), + accumulated_bind_names, ) ) compiler.implicit_returning.append(c) elif c.default.is_clause_element: + accumulated_bind_names = set() values.append( ( c, compiler.preparer.format_column(c), - compiler.process(c.default.arg.self_group(), **kw), + compiler.process( + c.default.arg.self_group(), + accumulate_bind_names=accumulated_bind_names, + **kw, + ), + accumulated_bind_names, ) ) compiler.implicit_returning.append(c) @@ -869,6 +909,7 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): c, compiler.preparer.format_column(c), _create_insert_prefetch_bind_param(compiler, c, **kw), + (c.key,), ) ) elif c is stmt.table._autoincrement_column or c.server_default is not None: @@ -936,6 +977,7 @@ def _append_param_insert_pk_no_returning(compiler, stmt, c, values, kw): c, compiler.preparer.format_column(c), _create_insert_prefetch_bind_param(compiler, c, **kw), + (c.key,), ) ) elif ( @@ -962,11 +1004,17 @@ def _append_param_insert_hasdefault( if compiler.dialect.supports_sequences and ( not c.default.optional or not compiler.dialect.sequences_optional ): + accumulated_bind_names: Set[str] = set() values.append( ( c, compiler.preparer.format_column(c), - compiler.process(c.default, **kw), + compiler.process( + c.default, + accumulate_bind_names=accumulated_bind_names, + **kw, + ), + accumulated_bind_names, ) ) if implicit_return_defaults and c in implicit_return_defaults: @@ -974,11 +1022,17 @@ def _append_param_insert_hasdefault( elif not c.primary_key: compiler.postfetch.append(c) elif c.default.is_clause_element: + accumulated_bind_names = set() values.append( ( c, compiler.preparer.format_column(c), - compiler.process(c.default.arg.self_group(), **kw), + compiler.process( + c.default.arg.self_group(), + accumulate_bind_names=accumulated_bind_names, + **kw, + ), + accumulated_bind_names, ) ) @@ -993,6 +1047,7 @@ def _append_param_insert_hasdefault( c, compiler.preparer.format_column(c), _create_insert_prefetch_bind_param(compiler, c, **kw), + (c.key,), ) ) @@ -1001,7 +1056,7 @@ def _append_param_insert_select_hasdefault( compiler: SQLCompiler, stmt: ValuesBase, c: ColumnClause[Any], - values: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]], + values: List[_CrudParamElementSQLExpr], kw: Dict[str, Any], ) -> None: @@ -1014,6 +1069,7 @@ def _append_param_insert_select_hasdefault( c, compiler.preparer.format_column(c), c.default.next_value(), + (), ) ) elif default_is_clause_element(c.default): @@ -1022,6 +1078,7 @@ def _append_param_insert_select_hasdefault( c, compiler.preparer.format_column(c), c.default.arg.self_group(), + (), ) ) else: @@ -1032,6 +1089,7 @@ def _append_param_insert_select_hasdefault( _create_insert_prefetch_bind_param( compiler, c, process=False, **kw ), + (c.key,), ) ) @@ -1051,6 +1109,7 @@ def _append_param_update( use_table=include_table, ), compiler.process(c.onupdate.arg.self_group(), **kw), + (), ) ) if implicit_return_defaults and c in implicit_return_defaults: @@ -1066,6 +1125,7 @@ def _append_param_update( use_table=include_table, ), _create_update_prefetch_bind_param(compiler, c, **kw), + (c.key,), ) ) elif c.server_onupdate is not None: @@ -1177,7 +1237,7 @@ class _multiparam_column(elements.ColumnElement[Any]): def _process_multiparam_default_bind( compiler: SQLCompiler, stmt: ValuesBase, - c: ColumnClause[Any], + c: KeyedColumnElement[Any], index: int, kw: Dict[str, Any], ) -> str: @@ -1243,14 +1303,18 @@ def _get_update_multitable_params( name=_col_bind_name(c), **kw, # TODO: no test coverage for literal binds here ) + accumulated_bind_names: Iterable[str] = (c.key,) elif value._is_bind_parameter: + cbn = _col_bind_name(c) value = _handle_values_anonymous_param( - compiler, c, value, name=_col_bind_name(c), **kw + compiler, c, value, name=cbn, **kw ) + accumulated_bind_names = (cbn,) else: compiler.postfetch.append(c) value = compiler.process(value.self_group(), **kw) - values.append((c, col_value, value)) + accumulated_bind_names = () + values.append((c, col_value, value, accumulated_bind_names)) # determine tables which are actually to be updated - process onupdate # and server_onupdate for these for t in affected_tables: @@ -1266,6 +1330,7 @@ def _get_update_multitable_params( compiler.process( c.onupdate.arg.self_group(), **kw ), + (), ) ) compiler.postfetch.append(c) @@ -1277,6 +1342,7 @@ def _get_update_multitable_params( _create_update_prefetch_bind_param( compiler, c, name=_col_bind_name(c), **kw ), + (c.key,), ) ) elif c.server_onupdate is not None: @@ -1287,21 +1353,21 @@ def _extend_values_for_multiparams( compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState, - initial_values: Sequence[Tuple[ColumnClause[Any], str, str]], + initial_values: Sequence[_CrudParamElementStr], _column_as_key: Callable[..., str], kw: Dict[str, Any], -) -> List[Sequence[Tuple[ColumnClause[Any], str, str]]]: +) -> List[Sequence[_CrudParamElementStr]]: values_0 = initial_values values = [initial_values] mp = compile_state._multi_parameters assert mp is not None for i, row in enumerate(mp[1:]): - extension: List[Tuple[ColumnClause[Any], str, str]] = [] + extension: List[_CrudParamElementStr] = [] row = {_column_as_key(key): v for key, v in row.items()} - for (col, col_expr, param) in values_0: + for (col, col_expr, param, accumulated_names) in values_0: if col.key in row: key = col.key @@ -1320,7 +1386,7 @@ def _extend_values_for_multiparams( compiler, stmt, col, i, kw ) - extension.append((col, col_expr, new_param)) + extension.append((col, col_expr, new_param, accumulated_names)) values.append(extension) @@ -1366,7 +1432,8 @@ def _get_stmt_parameter_tuples_params( v = compiler.process(v.self_group(), **kw) - values.append((k, col_expr, v)) + # TODO: not sure if accumulated_bind_names applies here + values.append((k, col_expr, v, ())) def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): diff --git a/test/sql/test_insert_exec.py b/test/sql/test_insert_exec.py index 4ce093156d..429ebf163c 100644 --- a/test/sql/test_insert_exec.py +++ b/test/sql/test_insert_exec.py @@ -1,6 +1,7 @@ import itertools from sqlalchemy import and_ +from sqlalchemy import bindparam from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import ForeignKey @@ -8,6 +9,7 @@ from sqlalchemy import func from sqlalchemy import INT from sqlalchemy import Integer from sqlalchemy import literal +from sqlalchemy import select from sqlalchemy import Sequence from sqlalchemy import sql from sqlalchemy import String @@ -741,6 +743,14 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest): Column("\u6e2c\u8a66", Integer), ) + Table( + "extra_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x_value", String(50)), + Column("y_value", String(50)), + ) + def test_insert_unicode_keys(self, connection): table = self.tables["Unitéble2"] @@ -807,6 +817,88 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest): eq_(result.inserted_primary_key_rows, [(1,), (2,), (3,)]) + @testing.combinations(True, False, argnames="use_returning") + @testing.combinations(1, 2, argnames="num_embedded_params") + @testing.combinations(True, False, argnames="use_whereclause") + @testing.crashes( + "+mariadbconnector", + "returning crashes, regular executemany malfunctions", + ) + def test_insert_w_bindparam_in_subq( + self, connection, use_returning, num_embedded_params, use_whereclause + ): + """test #8639""" + + t = self.tables.data + extra = self.tables.extra_table + + conn = connection + connection.execute( + extra.insert(), + [ + {"x_value": "p1", "y_value": "yv1"}, + {"x_value": "p2", "y_value": "yv2"}, + {"x_value": "p1_p1", "y_value": "yv3"}, + {"x_value": "p2_p2", "y_value": "yv4"}, + ], + ) + + if num_embedded_params == 1: + if use_whereclause: + scalar_subq = select(bindparam("paramname")).scalar_subquery() + params = [ + {"paramname": "p1_p1", "y": "y1"}, + {"paramname": "p2_p2", "y": "y2"}, + ] + else: + scalar_subq = ( + select(extra.c.x_value) + .where(extra.c.y_value == bindparam("y_value")) + .scalar_subquery() + ) + params = [ + {"y_value": "yv3", "y": "y1"}, + {"y_value": "yv4", "y": "y2"}, + ] + + elif num_embedded_params == 2: + if use_whereclause: + scalar_subq = ( + select( + bindparam("paramname1", type_=String) + extra.c.x_value + ) + .where(extra.c.y_value == bindparam("y_value")) + .scalar_subquery() + ) + params = [ + {"paramname1": "p1_", "y_value": "yv1", "y": "y1"}, + {"paramname1": "p2_", "y_value": "yv2", "y": "y2"}, + ] + else: + scalar_subq = select( + bindparam("paramname1", type_=String) + + bindparam("paramname2", type_=String) + ).scalar_subquery() + params = [ + {"paramname1": "p1_", "paramname2": "p1", "y": "y1"}, + {"paramname1": "p2_", "paramname2": "p2", "y": "y2"}, + ] + else: + assert False + + stmt = t.insert().values(x=scalar_subq) + if use_returning: + stmt = stmt.returning(t.c["x", "y"]) + + result = conn.execute(stmt, params) + + if use_returning: + eq_(result.all(), [("p1_p1", "y1"), ("p2_p2", "y2")]) + + result = conn.execute(select(t.c["x", "y"])) + + eq_(result.all(), [("p1_p1", "y1"), ("p2_p2", "y2")]) + def test_insert_returning_defaults(self, connection): t = self.tables.data diff --git a/test/sql/test_type_expressions.py b/test/sql/test_type_expressions.py index 901be7132a..6c717123bd 100644 --- a/test/sql/test_type_expressions.py +++ b/test/sql/test_type_expressions.py @@ -365,6 +365,22 @@ class DerivedTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL): class RoundTripTestBase: + @testing.requires.insertmanyvalues + def test_insertmanyvalues_returning(self, connection): + tt = self.tables.test_table + result = connection.execute( + tt.insert().returning(tt.c["x", "y"]), + [ + {"x": "X1", "y": "Y1"}, + {"x": "X2", "y": "Y2"}, + {"x": "X3", "y": "Y3"}, + ], + ) + eq_( + result.all(), + [("X1", "Y1"), ("X2", "Y2"), ("X3", "Y3")], + ) + def test_round_trip(self, connection): connection.execute( self.tables.test_table.insert(),