]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Sweep through UPDATE ordered_values a second time
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 13 Aug 2020 00:14:15 +0000 (20:14 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 13 Aug 2020 02:00:43 +0000 (22:00 -0400)
The fix in 180ae7c1a53385f72b0047496ac001ec5099cc3e
didn't do much as the code was not preserving parameter
order at all, in fact.    Reworked stmt_parameters to be
delivered in the correct order up front and preserve
throughout crud.py which was not being done at all
before.

Fixes: #5510
Change-Id: I0795c71df73005a25d1bbf216732d41b41e11a5f

lib/sqlalchemy/sql/crud.py
lib/sqlalchemy/sql/dml.py
test/sql/test_update.py

index b049b2f3372a68ea16ea960c858ff990a4977417..3bf8a7c624e4d91b79eaadc1467d93f97aa8e5f1 100644 (file)
@@ -74,30 +74,41 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
         ]
 
     if compile_state._has_multi_parameters:
-        stmt_parameters = compile_state._multi_parameters[0]
+        spd = compile_state._multi_parameters[0]
+        stmt_parameter_tuples = list(spd.items())
+    elif compile_state._ordered_values:
+        spd = compile_state._dict_parameters
+        stmt_parameter_tuples = compile_state._ordered_values
+    elif compile_state._dict_parameters:
+        spd = compile_state._dict_parameters
+        stmt_parameter_tuples = list(spd.items())
     else:
-        stmt_parameters = compile_state._dict_parameters
+        stmt_parameter_tuples = spd = None
 
     # if we have statement parameters - set defaults in the
     # compiled params
     if compiler.column_keys is None:
         parameters = {}
-    else:
+    elif stmt_parameter_tuples:
         parameters = dict(
             (_column_as_key(key), REQUIRED)
             for key in compiler.column_keys
-            if not stmt_parameters or key not in stmt_parameters
+            if key not in spd
+        )
+    else:
+        parameters = dict(
+            (_column_as_key(key), REQUIRED) for key in compiler.column_keys
         )
 
     # create a list of column assignment clauses as tuples
     values = []
 
-    if stmt_parameters is not None:
-        _get_stmt_parameters_params(
+    if stmt_parameter_tuples is not None:
+        _get_stmt_parameter_tuples_params(
             compiler,
             compile_state,
             parameters,
-            stmt_parameters,
+            stmt_parameter_tuples,
             _column_as_key,
             values,
             kw,
@@ -112,7 +123,7 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
             compiler,
             stmt,
             compile_state,
-            stmt_parameters,
+            stmt_parameter_tuples,
             check_columns,
             _col_bind_name,
             _getattr_col_key,
@@ -147,10 +158,10 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
             kw,
         )
 
-    if parameters and stmt_parameters:
+    if parameters and stmt_parameter_tuples:
         check = (
             set(parameters)
-            .intersection(_column_as_key(k) for k in stmt_parameters)
+            .intersection(_column_as_key(k) for k, v in stmt_parameter_tuples)
             .difference(check_columns)
         )
         if check:
@@ -342,6 +353,7 @@ def _scan_cols(
             for key in parameter_ordering
             if isinstance(key, util.string_types) and key in stmt.table.c
         ] + [c for c in stmt.table.c if c.key not in ordered_keys]
+
     else:
         cols = stmt.table.columns
 
@@ -757,7 +769,7 @@ def _get_multitable_params(
     compiler,
     stmt,
     compile_state,
-    stmt_parameters,
+    stmt_parameter_tuples,
     check_columns,
     _col_bind_name,
     _getattr_col_key,
@@ -766,7 +778,7 @@ def _get_multitable_params(
 ):
     normalized_params = dict(
         (coercions.expect(roles.DMLColumnRole, c), param)
-        for c, param in stmt_parameters.items()
+        for c, param in stmt_parameter_tuples
     )
 
     include_table = compile_state.include_table_with_column_exprs
@@ -861,17 +873,17 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw):
     return values
 
 
-def _get_stmt_parameters_params(
+def _get_stmt_parameter_tuples_params(
     compiler,
     compile_state,
     parameters,
-    stmt_parameters,
+    stmt_parameter_tuples,
     _column_as_key,
     values,
     kw,
 ):
 
-    for k, v in stmt_parameters.items():
+    for k, v in stmt_parameter_tuples:
         colkey = _column_as_key(k)
         if colkey is not None:
             parameters.setdefault(colkey, v)
index 9308658984ba784308ea082135b09d52982f2183..a9bccaeff88cfac3efd621fef8d0ae521ec6225e 100644 (file)
@@ -34,6 +34,7 @@ class DMLState(CompileState):
     _no_parameters = True
     _dict_parameters = None
     _multi_parameters = None
+    _ordered_values = None
     _parameter_ordering = None
     _has_multi_parameters = False
     isupdate = False
@@ -97,6 +98,7 @@ class DMLState(CompileState):
         if self._no_parameters:
             self._no_parameters = False
             self._dict_parameters = dict(parameters)
+            self._ordered_values = parameters
             self._parameter_ordering = [key for key, value in parameters]
         elif self._has_multi_parameters:
             self._cant_mix_formats_error()
index 1eed378326bd418d8fe21b0b86c77ea74fb3a233..964e3ee6be377a85eddadb89bc3a7b61e1c1af13 100644 (file)
@@ -1,3 +1,6 @@
+import itertools
+import random
+
 from sqlalchemy import bindparam
 from sqlalchemy import column
 from sqlalchemy import exc
@@ -23,6 +26,7 @@ from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import mock
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 
@@ -856,34 +860,125 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             "UPDATE mytable SET foo(myid)=:param_1",
         )
 
-    def test_update_to_expression_two(self):
-        """test update from an expression.
+    @testing.fixture
+    def randomized_param_order_update(self):
+        from sqlalchemy.sql.dml import UpdateDMLState
 
-        this logic is triggered currently by a left side that doesn't
-        have a key.  The current supported use case is updating the index
-        of a PostgreSQL ARRAY type.
+        super_process_ordered_values = UpdateDMLState._process_ordered_values
 
-        """
+        # this fixture is needed for Python 3.6 and above to work around
+        # dictionaries being insert-ordered.  in python 2.7 the previous
+        # logic fails pretty easily without this fixture.
+        def _process_ordered_values(self, statement):
+            super_process_ordered_values(self, statement)
+
+            tuples = list(self._dict_parameters.items())
+            random.shuffle(tuples)
+            self._dict_parameters = dict(tuples)
+
+        dialect = default.StrCompileDialect()
+        dialect.paramstyle = "qmark"
+        dialect.positional = True
+
+        with mock.patch.object(
+            UpdateDMLState, "_process_ordered_values", _process_ordered_values
+        ):
+            yield
 
+    def random_update_order_parameters():
         from sqlalchemy import ARRAY
 
         t = table(
             "foo",
             column("data1", ARRAY(Integer)),
             column("data2", ARRAY(Integer)),
+            column("data3", ARRAY(Integer)),
+            column("data4", ARRAY(Integer)),
         )
 
+        idx_to_value = [
+            (t.c.data1, 5, 7),
+            (t.c.data2, 10, 18),
+            (t.c.data3, 8, 4),
+            (t.c.data4, 12, 14),
+        ]
+
+        def combinations():
+            while True:
+                random.shuffle(idx_to_value)
+                yield list(idx_to_value)
+
+        return testing.combinations(
+            *[
+                (t, combination)
+                for i, combination in zip(range(10), combinations())
+            ],
+            argnames="t, idx_to_value"
+        )
+
+    @random_update_order_parameters()
+    def test_update_to_expression_two(
+        self, randomized_param_order_update, t, idx_to_value
+    ):
+        """test update from an expression.
+
+        this logic is triggered currently by a left side that doesn't
+        have a key.  The current supported use case is updating the index
+        of a PostgreSQL ARRAY type.
+
+        """
+
+        dialect = default.StrCompileDialect()
+        dialect.paramstyle = "qmark"
+        dialect.positional = True
+
         stmt = t.update().ordered_values(
-            (t.c.data1[5], 7), (t.c.data2[10], 18)
+            *[(col[idx], val) for col, idx, val in idx_to_value]
         )
+
+        self.assert_compile(
+            stmt,
+            "UPDATE foo SET %s"
+            % (
+                ", ".join(
+                    "%s[?]=?" % col.key for col, idx, val in idx_to_value
+                )
+            ),
+            dialect=dialect,
+            checkpositional=tuple(
+                itertools.chain.from_iterable(
+                    (idx, val) for col, idx, val in idx_to_value
+                )
+            ),
+        )
+
+    @random_update_order_parameters()
+    def test_update_to_expression_ppo(
+        self, randomized_param_order_update, t, idx_to_value
+    ):
         dialect = default.StrCompileDialect()
         dialect.paramstyle = "qmark"
         dialect.positional = True
+
+        # deprecated pattern here
+        stmt = t.update(preserve_parameter_order=True).values(
+            [(col[idx], val) for col, idx, val in idx_to_value]
+        )
+
         self.assert_compile(
             stmt,
-            "UPDATE foo SET data1[?]=?, data2[?]=?",
+            "UPDATE foo SET %s"
+            % (
+                ", ".join(
+                    "%s[?]=?" % col.key for col, idx, val in idx_to_value
+                )
+            ),
             dialect=dialect,
-            checkpositional=(5, 7, 10, 18),
+            checkpositional=tuple(
+                itertools.chain.from_iterable(
+                    (idx, val) for col, idx, val in idx_to_value
+                )
+            ),
         )
 
     def test_update_to_expression_three(self):