From: Mike Bayer Date: Thu, 13 Aug 2020 00:14:15 +0000 (-0400) Subject: Sweep through UPDATE ordered_values a second time X-Git-Tag: rel_1_4_0b1~181 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=65da69910944ccbad0c6d008b94ae8271aae4762;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Sweep through UPDATE ordered_values a second time The fix in 180ae7c1a53385f72b0047496ac001ec5099cc3e didn't do much as the code was not preserving parameter order at all, in fact. Reworked stmt_parameters to be delivered in the correct order up front and preserve throughout crud.py which was not being done at all before. Fixes: #5510 Change-Id: I0795c71df73005a25d1bbf216732d41b41e11a5f --- diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index b049b2f337..3bf8a7c624 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -74,30 +74,41 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): ] if compile_state._has_multi_parameters: - stmt_parameters = compile_state._multi_parameters[0] + spd = compile_state._multi_parameters[0] + stmt_parameter_tuples = list(spd.items()) + elif compile_state._ordered_values: + spd = compile_state._dict_parameters + stmt_parameter_tuples = compile_state._ordered_values + elif compile_state._dict_parameters: + spd = compile_state._dict_parameters + stmt_parameter_tuples = list(spd.items()) else: - stmt_parameters = compile_state._dict_parameters + stmt_parameter_tuples = spd = None # if we have statement parameters - set defaults in the # compiled params if compiler.column_keys is None: parameters = {} - else: + elif stmt_parameter_tuples: parameters = dict( (_column_as_key(key), REQUIRED) for key in compiler.column_keys - if not stmt_parameters or key not in stmt_parameters + if key not in spd + ) + else: + parameters = dict( + (_column_as_key(key), REQUIRED) for key in compiler.column_keys ) # create a list of column assignment clauses as tuples values = [] - if stmt_parameters is not None: - _get_stmt_parameters_params( + if stmt_parameter_tuples is not None: + _get_stmt_parameter_tuples_params( compiler, compile_state, parameters, - stmt_parameters, + stmt_parameter_tuples, _column_as_key, values, kw, @@ -112,7 +123,7 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): compiler, stmt, compile_state, - stmt_parameters, + stmt_parameter_tuples, check_columns, _col_bind_name, _getattr_col_key, @@ -147,10 +158,10 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): kw, ) - if parameters and stmt_parameters: + if parameters and stmt_parameter_tuples: check = ( set(parameters) - .intersection(_column_as_key(k) for k in stmt_parameters) + .intersection(_column_as_key(k) for k, v in stmt_parameter_tuples) .difference(check_columns) ) if check: @@ -342,6 +353,7 @@ def _scan_cols( for key in parameter_ordering if isinstance(key, util.string_types) and key in stmt.table.c ] + [c for c in stmt.table.c if c.key not in ordered_keys] + else: cols = stmt.table.columns @@ -757,7 +769,7 @@ def _get_multitable_params( compiler, stmt, compile_state, - stmt_parameters, + stmt_parameter_tuples, check_columns, _col_bind_name, _getattr_col_key, @@ -766,7 +778,7 @@ def _get_multitable_params( ): normalized_params = dict( (coercions.expect(roles.DMLColumnRole, c), param) - for c, param in stmt_parameters.items() + for c, param in stmt_parameter_tuples ) include_table = compile_state.include_table_with_column_exprs @@ -861,17 +873,17 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw): return values -def _get_stmt_parameters_params( +def _get_stmt_parameter_tuples_params( compiler, compile_state, parameters, - stmt_parameters, + stmt_parameter_tuples, _column_as_key, values, kw, ): - for k, v in stmt_parameters.items(): + for k, v in stmt_parameter_tuples: colkey = _column_as_key(k) if colkey is not None: parameters.setdefault(colkey, v) diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 9308658984..a9bccaeff8 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -34,6 +34,7 @@ class DMLState(CompileState): _no_parameters = True _dict_parameters = None _multi_parameters = None + _ordered_values = None _parameter_ordering = None _has_multi_parameters = False isupdate = False @@ -97,6 +98,7 @@ class DMLState(CompileState): if self._no_parameters: self._no_parameters = False self._dict_parameters = dict(parameters) + self._ordered_values = parameters self._parameter_ordering = [key for key, value in parameters] elif self._has_multi_parameters: self._cant_mix_formats_error() diff --git a/test/sql/test_update.py b/test/sql/test_update.py index 1eed378326..964e3ee6be 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -1,3 +1,6 @@ +import itertools +import random + from sqlalchemy import bindparam from sqlalchemy import column from sqlalchemy import exc @@ -23,6 +26,7 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing import mock from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -856,34 +860,125 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): "UPDATE mytable SET foo(myid)=:param_1", ) - def test_update_to_expression_two(self): - """test update from an expression. + @testing.fixture + def randomized_param_order_update(self): + from sqlalchemy.sql.dml import UpdateDMLState - this logic is triggered currently by a left side that doesn't - have a key. The current supported use case is updating the index - of a PostgreSQL ARRAY type. + super_process_ordered_values = UpdateDMLState._process_ordered_values - """ + # this fixture is needed for Python 3.6 and above to work around + # dictionaries being insert-ordered. in python 2.7 the previous + # logic fails pretty easily without this fixture. + def _process_ordered_values(self, statement): + super_process_ordered_values(self, statement) + + tuples = list(self._dict_parameters.items()) + random.shuffle(tuples) + self._dict_parameters = dict(tuples) + + dialect = default.StrCompileDialect() + dialect.paramstyle = "qmark" + dialect.positional = True + + with mock.patch.object( + UpdateDMLState, "_process_ordered_values", _process_ordered_values + ): + yield + def random_update_order_parameters(): from sqlalchemy import ARRAY t = table( "foo", column("data1", ARRAY(Integer)), column("data2", ARRAY(Integer)), + column("data3", ARRAY(Integer)), + column("data4", ARRAY(Integer)), ) + idx_to_value = [ + (t.c.data1, 5, 7), + (t.c.data2, 10, 18), + (t.c.data3, 8, 4), + (t.c.data4, 12, 14), + ] + + def combinations(): + while True: + random.shuffle(idx_to_value) + yield list(idx_to_value) + + return testing.combinations( + *[ + (t, combination) + for i, combination in zip(range(10), combinations()) + ], + argnames="t, idx_to_value" + ) + + @random_update_order_parameters() + def test_update_to_expression_two( + self, randomized_param_order_update, t, idx_to_value + ): + """test update from an expression. + + this logic is triggered currently by a left side that doesn't + have a key. The current supported use case is updating the index + of a PostgreSQL ARRAY type. + + """ + + dialect = default.StrCompileDialect() + dialect.paramstyle = "qmark" + dialect.positional = True + stmt = t.update().ordered_values( - (t.c.data1[5], 7), (t.c.data2[10], 18) + *[(col[idx], val) for col, idx, val in idx_to_value] ) + + self.assert_compile( + stmt, + "UPDATE foo SET %s" + % ( + ", ".join( + "%s[?]=?" % col.key for col, idx, val in idx_to_value + ) + ), + dialect=dialect, + checkpositional=tuple( + itertools.chain.from_iterable( + (idx, val) for col, idx, val in idx_to_value + ) + ), + ) + + @random_update_order_parameters() + def test_update_to_expression_ppo( + self, randomized_param_order_update, t, idx_to_value + ): dialect = default.StrCompileDialect() dialect.paramstyle = "qmark" dialect.positional = True + + # deprecated pattern here + stmt = t.update(preserve_parameter_order=True).values( + [(col[idx], val) for col, idx, val in idx_to_value] + ) + self.assert_compile( stmt, - "UPDATE foo SET data1[?]=?, data2[?]=?", + "UPDATE foo SET %s" + % ( + ", ".join( + "%s[?]=?" % col.key for col, idx, val in idx_to_value + ) + ), dialect=dialect, - checkpositional=(5, 7, 10, 18), + checkpositional=tuple( + itertools.chain.from_iterable( + (idx, val) for col, idx, val in idx_to_value + ) + ), ) def test_update_to_expression_three(self):