From c0685e5f419e203acd5e46f25c90e851e30e6f03 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 8 Aug 2020 13:03:17 -0400 Subject: [PATCH] render INSERT/UPDATE column expressions up front; pass state Fixes related to rendering of complex UPDATE DML which was not correctly preserving positional parameter order in conjunction with DML features that are only known to work on the PostgreSQL database. Both pg8000 and asyncpg use positional parameters which is why these issues are suddenly apparent. crud.py now takes on the task of rendering the column expressions for SET or VALUES so that for the very unusual case that the column expression is a compound expression that includes a bound parameter (namely an array index), the bound parameter order is preserved. Additionally, crud.py passes through the positional_names keyword argument into bindparam_string() which is necessary when CTEs are being rendered, as PG supports complex CTE / INSERT / UPDATE scenarios. Change-Id: I7f03920500e19b721636b84594de78a5bfdcbc82 --- lib/sqlalchemy/sql/compiler.py | 19 ++-- lib/sqlalchemy/sql/crud.py | 193 ++++++++++++++++++++++++++------- lib/sqlalchemy/sql/dml.py | 10 +- test/sql/test_update.py | 68 +++++++++++- 4 files changed, 234 insertions(+), 56 deletions(-) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8e273f67c4..542bf58ac5 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -3286,7 +3286,7 @@ class SQLCompiler(Compiled): if crud_params_single or not supports_default_values: text += " (%s)" % ", ".join( - [preparer.format_column(c[0]) for c in crud_params_single] + [expr for c, expr, value in crud_params_single] ) if self.returning or insert_stmt._returning: @@ -3311,12 +3311,15 @@ class SQLCompiler(Compiled): elif compile_state._has_multi_parameters: text += " VALUES %s" % ( ", ".join( - "(%s)" % (", ".join(c[1] for c in crud_param_set)) + "(%s)" + % (", ".join(value for c, expr, value in crud_param_set)) for crud_param_set in crud_params ) ) else: - insert_single_values_expr = ", ".join([c[1] for c in crud_params]) + insert_single_values_expr = ", ".join( + [value for c, expr, value in crud_params] + ) text += " VALUES (%s)" % insert_single_values_expr if toplevel: self.insert_single_values_expr = insert_single_values_expr @@ -3424,15 +3427,7 @@ class SQLCompiler(Compiled): text += table_text text += " SET " - include_table = ( - is_multitable and self.render_table_with_column_in_update_from - ) - text += ", ".join( - c[0]._compiler_dispatch(self, include_table=include_table) - + "=" - + c[1] - for c in crud_params - ) + text += ", ".join(expr + "=" + value for c, expr, value in crud_params) if self.returning or update_stmt._returning: if self.returning_precedes_values: diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 85112f8506..7d0616da71 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -65,7 +65,11 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): # compiled params - return binds for all columns if compiler.column_keys is None and compile_state._no_parameters: return [ - (c, _create_bind_param(compiler, c, None, required=True)) + ( + c, + compiler.preparer.format_column(c), + _create_bind_param(compiler, c, None, required=True), + ) for c in stmt.table.columns ] @@ -90,18 +94,20 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): if stmt_parameters is not None: _get_stmt_parameters_params( - compiler, parameters, stmt_parameters, _column_as_key, values, kw + compiler, + compile_state, + parameters, + stmt_parameters, + _column_as_key, + values, + kw, ) check_columns = {} # special logic that only occurs for multi-table UPDATE # statements - if ( - compile_state.isupdate - and compile_state._extra_froms - and stmt_parameters - ): + if compile_state.isupdate and compile_state.is_multitable: _get_multitable_params( compiler, stmt, @@ -162,7 +168,13 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): # into INSERT (firstcol) VALUES (DEFAULT) which can be turned # into an in-place multi values. This supports # insert_executemany_returning mode :) - values = [(stmt.table.columns[0], "DEFAULT")] + values = [ + ( + stmt.table.columns[0], + compiler.preparer.format_column(stmt.table.columns[0]), + "DEFAULT", + ) + ] return values @@ -286,7 +298,7 @@ def _scan_insert_from_select_cols( col_key = _getattr_col_key(c) if col_key in parameters and col_key not in check_columns: parameters.pop(col_key) - values.append((c, None)) + values.append((c, compiler.preparer.format_column(c), None)) else: _append_param_insert_select_hasdefault( compiler, stmt, c, add_select_cols, kw @@ -297,7 +309,7 @@ def _scan_insert_from_select_cols( compiler._insert_from_select = compiler._insert_from_select._generate() compiler._insert_from_select._raw_columns = tuple( compiler._insert_from_select._raw_columns - ) + tuple(expr for col, expr in add_select_cols) + ) + tuple(expr for col, col_expr, expr in add_select_cols) def _scan_cols( @@ -390,7 +402,13 @@ def _scan_cols( elif compile_state.isupdate: _append_param_update( - compiler, stmt, c, implicit_return_defaults, values, kw + compiler, + compile_state, + stmt, + c, + implicit_return_defaults, + values, + kw, ) @@ -410,6 +428,10 @@ def _append_param_parameter( value = parameters.pop(col_key) + col_value = compiler.preparer.format_column( + c, use_table=compile_state.include_table_with_column_exprs + ) + if coercions._is_literal(value): value = _create_bind_param( compiler, @@ -446,7 +468,7 @@ def _append_param_parameter( if not c.primary_key: compiler.postfetch.append(c) value = compiler.process(value.self_group(), **kw) - values.append((c, value)) + values.append((c, col_value, value)) def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): @@ -472,16 +494,31 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): not c.default.optional or not compiler.dialect.sequences_optional ): - proc = compiler.process(c.default, **kw) - values.append((c, proc)) + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default, **kw), + ) + ) compiler.returning.append(c) elif c.default.is_clause_element: values.append( - (c, compiler.process(c.default.arg.self_group(), **kw)) + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default.arg.self_group(), **kw), + ) ) compiler.returning.append(c) else: - values.append((c, _create_insert_prefetch_bind_param(compiler, c))) + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + ) + ) elif c is stmt.table._autoincrement_column or c.server_default is not None: compiler.returning.append(c) elif not c.nullable: @@ -490,14 +527,22 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): _warn_pk_with_no_anticipated_value(c) -def _create_insert_prefetch_bind_param(compiler, c, process=True, name=None): - param = _create_bind_param(compiler, c, None, process=process, name=name) +def _create_insert_prefetch_bind_param( + compiler, c, process=True, name=None, **kw +): + param = _create_bind_param( + compiler, c, None, process=process, name=name, **kw + ) compiler.insert_prefetch.append(c) return param -def _create_update_prefetch_bind_param(compiler, c, process=True, name=None): - param = _create_bind_param(compiler, c, None, process=process, name=name) +def _create_update_prefetch_bind_param( + compiler, c, process=True, name=None, **kw +): + param = _create_bind_param( + compiler, c, None, process=process, name=name, **kw + ) compiler.update_prefetch.append(c) return param @@ -539,9 +584,9 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw): else: col = _multiparam_column(c, index) if isinstance(stmt, dml.Insert): - return _create_insert_prefetch_bind_param(compiler, col) + return _create_insert_prefetch_bind_param(compiler, col, **kw) else: - return _create_update_prefetch_bind_param(compiler, col) + return _create_update_prefetch_bind_param(compiler, col, **kw) def _append_param_insert_pk(compiler, stmt, c, values, kw): @@ -582,7 +627,13 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw): or compiler.dialect.preexecute_autoincrement_sequences ) ): - values.append((c, _create_insert_prefetch_bind_param(compiler, c))) + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + ) + ) elif c.default is None and c.server_default is None and not c.nullable: # no .default, no .server_default, not autoincrement, we have # no indication this primary key column will have any value @@ -597,15 +648,25 @@ def _append_param_insert_hasdefault( if compiler.dialect.supports_sequences and ( not c.default.optional or not compiler.dialect.sequences_optional ): - proc = compiler.process(c.default, **kw) - values.append((c, proc)) + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default, **kw), + ) + ) if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) elif not c.primary_key: compiler.postfetch.append(c) elif c.default.is_clause_element: - proc = compiler.process(c.default.arg.self_group(), **kw) - values.append((c, proc)) + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default.arg.self_group(), **kw), + ) + ) if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) @@ -613,7 +674,13 @@ def _append_param_insert_hasdefault( # don't add primary key column to postfetch compiler.postfetch.append(c) else: - values.append((c, _create_insert_prefetch_bind_param(compiler, c))) + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + ) + ) def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw): @@ -622,32 +689,55 @@ def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw): if compiler.dialect.supports_sequences and ( not c.default.optional or not compiler.dialect.sequences_optional ): - proc = c.default - values.append((c, proc.next_value())) + values.append( + (c, compiler.preparer.format_column(c), c.default.next_value()) + ) elif c.default.is_clause_element: - proc = c.default.arg.self_group() - values.append((c, proc)) + values.append( + (c, compiler.preparer.format_column(c), c.default.arg.self_group()) + ) else: values.append( - (c, _create_insert_prefetch_bind_param(compiler, c, process=False)) + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param( + compiler, c, process=False, **kw + ), + ) ) def _append_param_update( - compiler, stmt, c, implicit_return_defaults, values, kw + compiler, compile_state, stmt, c, implicit_return_defaults, values, kw ): + include_table = compile_state.include_table_with_column_exprs if c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: values.append( - (c, compiler.process(c.onupdate.arg.self_group(), **kw)) + ( + c, + compiler.preparer.format_column( + c, use_table=include_table, + ), + compiler.process(c.onupdate.arg.self_group(), **kw), + ) ) if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) else: compiler.postfetch.append(c) else: - values.append((c, _create_update_prefetch_bind_param(compiler, c))) + values.append( + ( + c, + compiler.preparer.format_column( + c, use_table=include_table, + ), + _create_update_prefetch_bind_param(compiler, c, **kw), + ) + ) elif c.server_onupdate is not None: if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) @@ -676,6 +766,9 @@ def _get_multitable_params( (coercions.expect(roles.DMLColumnRole, c), param) for c, param in stmt_parameters.items() ) + + include_table = compile_state.include_table_with_column_exprs + affected_tables = set() for t in compile_state._extra_froms: for c in t.c: @@ -683,6 +776,8 @@ def _get_multitable_params( affected_tables.add(t) check_columns[_getattr_col_key(c)] = c value = normalized_params[c] + + col_value = compiler.process(c, include_table=include_table) if coercions._is_literal(value): value = _create_bind_param( compiler, @@ -699,7 +794,7 @@ def _get_multitable_params( else: compiler.postfetch.append(c) value = compiler.process(value.self_group(), **kw) - values.append((c, value)) + values.append((c, col_value, value)) # determine tables which are actually to be updated - process onupdate # and server_onupdate for these for t in affected_tables: @@ -711,6 +806,7 @@ def _get_multitable_params( values.append( ( c, + compiler.process(c, include_table=include_table), compiler.process( c.onupdate.arg.self_group(), **kw ), @@ -721,8 +817,9 @@ def _get_multitable_params( values.append( ( c, + compiler.process(c, include_table=include_table), _create_update_prefetch_bind_param( - compiler, c, name=_col_bind_name(c) + compiler, c, name=_col_bind_name(c), **kw ), ) ) @@ -736,7 +833,7 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw): for i, row in enumerate(compile_state._multi_parameters[1:]): extension = [] - for (col, param) in values_0: + for (col, col_expr, param) in values_0: if col in row or col.key in row: key = col if col in row else col.key @@ -755,7 +852,7 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw): compiler, stmt, col, i, kw ) - extension.append((col, new_param)) + extension.append((col, col_expr, new_param)) values.append(extension) @@ -763,8 +860,15 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw): def _get_stmt_parameters_params( - compiler, parameters, stmt_parameters, _column_as_key, values, kw + compiler, + compile_state, + parameters, + stmt_parameters, + _column_as_key, + values, + kw, ): + for k, v in stmt_parameters.items(): colkey = _column_as_key(k) if colkey is not None: @@ -773,6 +877,11 @@ def _get_stmt_parameters_params( # a non-Column expression on the left side; # add it to values() in an "as-is" state, # coercing right side to bound param + + col_expr = compiler.process( + k, include_table=compile_state.include_table_with_column_exprs + ) + if coercions._is_literal(v): v = compiler.process( elements.BindParameter(None, v, type_=k.type), **kw @@ -780,7 +889,7 @@ def _get_stmt_parameters_params( else: v = compiler.process(v.self_group(), **kw) - values.append((k, v)) + values.append((k, col_expr, v)) def _get_returning_modifiers(compiler, stmt, compile_state): diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 21476c1f91..9308658984 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -133,6 +133,8 @@ class DMLState(CompileState): class InsertDMLState(DMLState): isinsert = True + include_table_with_column_exprs = False + def __init__(self, statement, compiler, **kw): self.statement = statement @@ -149,6 +151,8 @@ class InsertDMLState(DMLState): class UpdateDMLState(DMLState): isupdate = True + include_table_with_column_exprs = False + def __init__(self, statement, compiler, **kw): self.statement = statement self.isupdate = True @@ -159,7 +163,11 @@ class UpdateDMLState(DMLState): self._process_values(statement) elif statement._multi_values: self._process_multi_values(statement) - self._extra_froms = self._make_extra_froms(statement) + self._extra_froms = ef = self._make_extra_froms(statement) + self.is_multitable = mt = ef and self._dict_parameters + self.include_table_with_column_exprs = ( + mt and compiler.render_table_with_column_in_update_from + ) @CompileState.plugin_for("default", "delete") diff --git a/test/sql/test_update.py b/test/sql/test_update.py index 18e9da654f..5db5fed11c 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -840,7 +840,7 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): dialect=mysql.dialect(), ) - def test_update_to_expression(self): + def test_update_to_expression_one(self): """test update from an expression. this logic is triggered currently by a left side that doesn't @@ -856,6 +856,72 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): "UPDATE mytable SET foo(myid)=:param_1", ) + def test_update_to_expression_two(self): + """test update from an expression. + + this logic is triggered currently by a left side that doesn't + have a key. The current supported use case is updating the index + of a PostgreSQL ARRAY type. + + """ + + from sqlalchemy import ARRAY + + t = table( + "foo", + column("data1", ARRAY(Integer)), + column("data2", ARRAY(Integer)), + ) + + stmt = t.update().values({t.c.data1[5]: 7, t.c.data2[10]: 18}) + dialect = default.StrCompileDialect() + dialect.paramstyle = "qmark" + dialect.positional = True + self.assert_compile( + stmt, + "UPDATE foo SET data1[?]=?, data2[?]=?", + dialect=dialect, + checkpositional=(5, 7, 10, 18), + ) + + def test_update_to_expression_three(self): + # this test is from test_defaults but exercises a particular + # parameter ordering issue + metadata = MetaData() + + q = Table( + "q", + metadata, + Column("x", Integer, default=2), + Column("y", Integer, onupdate=5), + Column("z", Integer), + ) + + p = Table( + "p", + metadata, + Column("s", Integer), + Column("t", Integer), + Column("u", Integer, onupdate=1), + ) + + cte = ( + q.update().where(q.c.z == 1).values(x=7).returning(q.c.z).cte("c") + ) + stmt = select([p.c.s, cte.c.z]).where(p.c.s == cte.c.z) + + dialect = default.StrCompileDialect() + dialect.paramstyle = "qmark" + dialect.positional = True + + self.assert_compile( + stmt, + "WITH c AS (UPDATE q SET x=?, y=? WHERE q.z = ? RETURNING q.z) " + "SELECT p.s, c.z FROM p, c WHERE p.s = c.z", + checkpositional=(7, None, 1), + dialect=dialect, + ) + def test_update_bound_ordering(self): """test that bound parameters between the UPDATE and FROM clauses order correctly in different SQL compilation scenarios. -- 2.39.5