]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
simplify internal storage of DML ordered values
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 8 Apr 2025 01:41:29 +0000 (21:41 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 10 Apr 2025 14:28:56 +0000 (10:28 -0400)
towards some refactorings I will need to do for #12496, this
factors out the "_ordered_values" list of tuples that was used to
track UPDATE VALUES in a specific order.   The rationale for this
separate collection was due to Python dictionaries not maintaining
insert order.   Now that this is standard behavior in Python 3
we can use the same `statement._values` for param-ordered and
table-column-ordered UPDATE rendering.

Change-Id: Id6024ab06e5e3ba427174e7ba3630ff83d81f603

lib/sqlalchemy/orm/bulk_persistence.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/sql/crud.py
lib/sqlalchemy/sql/dml.py
test/orm/dml/test_update_delete_where.py
test/sql/test_update.py

index ce2efcebce7bbd84b62b39978c81b32b6e4564b7..2664c9f9798db1dfdc4b2c09b8b1a92670fc8e93 100644 (file)
@@ -1046,8 +1046,6 @@ class _BulkUDCompileState(_ORMDMLState):
     def _get_resolved_values(cls, mapper, statement):
         if statement._multi_values:
             return []
-        elif statement._ordered_values:
-            return list(statement._ordered_values)
         elif statement._values:
             return list(statement._values.items())
         else:
@@ -1468,9 +1466,7 @@ class _BulkORMUpdate(_BulkUDCompileState, UpdateDMLState):
         # are passed through to the new statement, which will then raise
         # InvalidRequestError because UPDATE doesn't support multi_values
         # right now.
-        if statement._ordered_values:
-            new_stmt._ordered_values = self._resolved_values
-        elif statement._values:
+        if statement._values:
             new_stmt._values = self._resolved_values
 
         new_crit = self._adjust_for_extra_criteria(
@@ -1557,7 +1553,7 @@ class _BulkORMUpdate(_BulkUDCompileState, UpdateDMLState):
 
         UpdateDMLState.__init__(self, statement, compiler, **kw)
 
-        if self._ordered_values:
+        if self._maintain_values_ordering:
             raise sa_exc.InvalidRequestError(
                 "bulk ORM UPDATE does not support ordered_values() for "
                 "custom UPDATE statements with bulk parameter sets.  Use a "
index d2f2b2b8f0afbff4c34af61ef93fe9d3af2fefc7..1d6b4abf665af6bf6fc55a40718b8cedfb7012f2 100644 (file)
@@ -456,8 +456,13 @@ def _collect_update_commands(
 
         pks = mapper._pks_by_table[table]
 
-        if use_orm_update_stmt is not None:
+        if (
+            use_orm_update_stmt is not None
+            and not use_orm_update_stmt._maintain_values_ordering
+        ):
             # TODO: ordered values, etc
+            # ORM bulk_persistence will raise for the maintain_values_ordering
+            # case right now
             value_params = use_orm_update_stmt._values
         else:
             value_params = {}
index c0c0c86bb9c6d808dd2bd382ce2a605aae31fe82..ca7448b58b76baef89d5af22fe28e1d513044321 100644 (file)
@@ -231,11 +231,6 @@ def _get_crud_params(
         spd = mp[0]
         stmt_parameter_tuples = list(spd.items())
         spd_str_key = {_column_as_key(key) for key in spd}
-    elif compile_state._ordered_values:
-        spd = compile_state._dict_parameters
-        stmt_parameter_tuples = compile_state._ordered_values
-        assert spd is not None
-        spd_str_key = {_column_as_key(key) for key in spd}
     elif compile_state._dict_parameters:
         spd = compile_state._dict_parameters
         stmt_parameter_tuples = list(spd.items())
@@ -617,9 +612,9 @@ def _scan_cols(
 
     assert compile_state.isupdate or compile_state.isinsert
 
-    if compile_state._parameter_ordering:
+    if compile_state._maintain_values_ordering:
         parameter_ordering = [
-            _column_as_key(key) for key in compile_state._parameter_ordering
+            _column_as_key(key) for key in compile_state._dict_parameters
         ]
         ordered_keys = set(parameter_ordering)
         cols = [
index 589f4f3504d65d7583392cf63d79c12da12d7b42..73e61de65d9a27a91dc3263e278b9d05de4e46e3 100644 (file)
@@ -124,8 +124,7 @@ class DMLState(CompileState):
     _multi_parameters: Optional[
         List[MutableMapping[_DMLColumnElement, Any]]
     ] = None
-    _ordered_values: Optional[List[Tuple[_DMLColumnElement, Any]]] = None
-    _parameter_ordering: Optional[List[_DMLColumnElement]] = None
+    _maintain_values_ordering: bool = False
     _primary_table: FromClause
     _supports_implicit_returning = True
 
@@ -348,7 +347,7 @@ class UpdateDMLState(DMLState):
         self.statement = statement
 
         self.isupdate = True
-        if statement._ordered_values is not None:
+        if statement._maintain_values_ordering:
             self._process_ordered_values(statement)
         elif statement._values is not None:
             self._process_values(statement)
@@ -364,14 +363,12 @@ class UpdateDMLState(DMLState):
         )
 
     def _process_ordered_values(self, statement: ValuesBase) -> None:
-        parameters = statement._ordered_values
-
+        parameters = statement._values
         if self._no_parameters:
             self._no_parameters = False
             assert parameters is not None
             self._dict_parameters = dict(parameters)
-            self._ordered_values = parameters
-            self._parameter_ordering = [key for key, value in parameters]
+            self._maintain_values_ordering = True
         else:
             raise exc.InvalidRequestError(
                 "Can only invoke ordered_values() once, and not mixed "
@@ -1003,7 +1000,7 @@ class ValuesBase(UpdateBase):
         ...,
     ] = ()
 
-    _ordered_values: Optional[List[Tuple[_DMLColumnElement, Any]]] = None
+    _maintain_values_ordering: bool = False
 
     _select_names: Optional[List[str]] = None
     _inline: bool = False
@@ -1016,12 +1013,13 @@ class ValuesBase(UpdateBase):
     @_generative
     @_exclusive_against(
         "_select_names",
-        "_ordered_values",
+        "_maintain_values_ordering",
         msgs={
             "_select_names": "This construct already inserts from a SELECT",
-            "_ordered_values": "This statement already has ordered "
+            "_maintain_values_ordering": "This statement already has ordered "
             "values present",
         },
+        defaults={"_maintain_values_ordering": False},
     )
     def values(
         self,
@@ -1590,7 +1588,7 @@ class Update(
             ("table", InternalTraversal.dp_clauseelement),
             ("_where_criteria", InternalTraversal.dp_clauseelement_tuple),
             ("_inline", InternalTraversal.dp_boolean),
-            ("_ordered_values", InternalTraversal.dp_dml_ordered_values),
+            ("_maintain_values_ordering", InternalTraversal.dp_boolean),
             ("_values", InternalTraversal.dp_dml_values),
             ("_returning", InternalTraversal.dp_clauseelement_tuple),
             ("_hints", InternalTraversal.dp_table_hint_list),
@@ -1614,7 +1612,6 @@ class Update(
     def __init__(self, table: _DMLTableArgument):
         super().__init__(table)
 
-    @_generative
     def ordered_values(self, *args: Tuple[_DMLColumnArgument, Any]) -> Self:
         """Specify the VALUES clause of this UPDATE statement with an explicit
         parameter ordering that will be maintained in the SET clause of the
@@ -1638,15 +1635,13 @@ class Update(
         """  # noqa: E501
         if self._values:
             raise exc.ArgumentError(
-                "This statement already has values present"
-            )
-        elif self._ordered_values:
-            raise exc.ArgumentError(
-                "This statement already has ordered values present"
+                "This statement already has "
+                f"{'ordered ' if self._maintain_values_ordering else ''}"
+                "values present"
             )
 
-        kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
-        self._ordered_values = kv_generator(self, args, True)
+        self = self.values(dict(args))
+        self._maintain_values_ordering = True
         return self
 
     @_generative
index 387ce161b867c5080e0734fd354716a0462fbd03..88a0549a8e3254f9d8287a919c3801051cd86b0a 100644 (file)
@@ -2023,10 +2023,10 @@ class UpdateDeleteTest(fixtures.MappedTest):
         def do_orm_execute(bulk_ud):
             cols = [
                 c.key
-                for c, v in (
+                for c in (
                     (
                         bulk_ud.result.context
-                    ).compiled.compile_state.statement._ordered_values
+                    ).compiled.compile_state.statement._values
                 )
             ]
             m1(cols)
@@ -2081,10 +2081,8 @@ class UpdateDeleteTest(fixtures.MappedTest):
         result = session.execute(stmt)
         cols = [
             c.key
-            for c, v in (
-                (
-                    result.context
-                ).compiled.compile_state.statement._ordered_values
+            for c in (
+                (result.context).compiled.compile_state.statement._values
             )
         ]
         eq_(["age_int", "name"], cols)
@@ -2102,9 +2100,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
         result = session.execute(stmt)
         cols = [
             c.key
-            for c, v in (
-                result.context
-            ).compiled.compile_state.statement._ordered_values
+            for c in (result.context).compiled.compile_state.statement._values
         ]
         eq_(["name", "age_int"], cols)
 
index febbf4345e999dceade507bf81f60acf0763a24b..b381cb010e8695ebf0cca050a2ffbee000c69503 100644 (file)
@@ -27,7 +27,6 @@ from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
-from sqlalchemy.testing import mock
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 
@@ -833,31 +832,6 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             "UPDATE mytable SET foo(myid)=:param_1",
         )
 
-    @testing.fixture
-    def randomized_param_order_update(self):
-        from sqlalchemy.sql.dml import UpdateDMLState
-
-        super_process_ordered_values = UpdateDMLState._process_ordered_values
-
-        # this fixture is needed for Python 3.6 and above to work around
-        # dictionaries being insert-ordered.  in python 2.7 the previous
-        # logic fails pretty easily without this fixture.
-        def _process_ordered_values(self, statement):
-            super_process_ordered_values(self, statement)
-
-            tuples = list(self._dict_parameters.items())
-            random.shuffle(tuples)
-            self._dict_parameters = dict(tuples)
-
-        dialect = default.StrCompileDialect()
-        dialect.paramstyle = "qmark"
-        dialect.positional = True
-
-        with mock.patch.object(
-            UpdateDMLState, "_process_ordered_values", _process_ordered_values
-        ):
-            yield
-
     def random_update_order_parameters():
         from sqlalchemy import ARRAY
 
@@ -890,9 +864,7 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
         )
 
     @random_update_order_parameters()
-    def test_update_to_expression_two(
-        self, randomized_param_order_update, t, idx_to_value
-    ):
+    def test_update_to_expression_two(self, t, idx_to_value):
         """test update from an expression.
 
         this logic is triggered currently by a left side that doesn't