]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- UPDATE statements can now be batched within an ORM flush
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 15 Aug 2014 00:47:49 +0000 (20:47 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 15 Aug 2014 00:47:49 +0000 (20:47 -0400)
into more performant executemany() call, similarly to how INSERT
statements can be batched; this will be invoked within flush
to the degree that subsequent UPDATE statements for the
same mapping and table involve the identical columns within the
VALUES clause, as well as that no VALUES-level SQL expressions
are embedded.
- some other inlinings within persistence.py

doc/build/changelog/changelog_10.rst
lib/sqlalchemy/orm/persistence.py
test/orm/test_unitofwork.py
test/orm/test_unitofworkv2.py

index fb14279ac1ac0e4d32a0b0a83a15efd939abdf3f..439d02c4799057170da5af1ec44fc5e5bd1225f3 100644 (file)
 .. changelog::
        :version: 1.0.0
 
+    .. change::
+        :tags: orm, feature
+
+        UPDATE statements can now be batched within an ORM flush
+        into more performant executemany() call, similarly to how INSERT
+        statements can be batched; this will be invoked within flush
+        to the degree that subsequent UPDATE statements for the
+        same mapping and table involve the identical columns within the
+        VALUES clause, as well as that no VALUES-level SQL expressions
+        are embedded.
+
     .. change::
         :tags: engine, bug
         :tickets: 3163
index 17ce2e6247fd0684b2f9d1c816faef64a6f5a634..9d39c39b07fba2baea0029d2180f57d6f24d56e4 100644 (file)
@@ -248,9 +248,10 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
 
         has_all_pks = True
         has_all_defaults = True
+        has_version_id_generator = mapper.version_id_generator is not False \
+            and mapper.version_id_col is not None
         for col in mapper._cols_by_table[table]:
-            if col is mapper.version_id_col and \
-                    mapper.version_id_generator is not False:
+            if has_version_id_generator and col is mapper.version_id_col:
                 val = mapper.version_id_generator(None)
                 params[col.key] = val
             else:
@@ -305,6 +306,7 @@ def _collect_update_commands(base_mapper, uowtransaction,
         value_params = {}
 
         hasdata = hasnull = False
+
         for col in mapper._cols_by_table[table]:
             if col is mapper.version_id_col:
                 params[col._label] = \
@@ -341,6 +343,7 @@ def _collect_update_commands(base_mapper, uowtransaction,
                 prop = mapper._columntoproperty[col]
                 history = state.manager[prop.key].impl.get_history(
                     state, state_dict,
+                    attributes.PASSIVE_OFF if col in pks else
                     attributes.PASSIVE_NO_INITIALIZE)
                 if history.added:
                     if isinstance(history.added[0],
@@ -381,8 +384,7 @@ def _collect_update_commands(base_mapper, uowtransaction,
                     else:
                         hasdata = True
                 elif col in pks:
-                    value = state.manager[prop.key].impl.get(
-                        state, state_dict)
+                    value = history.unchanged[0]
                     if value is None:
                         hasnull = True
                     params[col._label] = value
@@ -500,41 +502,63 @@ def _emit_update_statements(base_mapper, uowtransaction,
 
     statement = base_mapper._memo(('update', table), update_stmt)
 
-    rows = 0
-    for state, state_dict, params, mapper, \
-            connection, value_params in update:
-
-        if value_params:
-            c = connection.execute(
-                statement.values(value_params),
-                params)
+    for (connection, paramkeys, hasvalue), \
+        records in groupby(
+            update,
+            lambda rec: (
+                rec[4],
+                tuple(sorted(rec[2])),
+                bool(rec[5]))
+            ):
+
+        rows = 0
+        records = list(records)
+        if hasvalue:
+            for state, state_dict, params, mapper, \
+                    connection, value_params in records:
+                c = connection.execute(
+                    statement.values(value_params),
+                    params)
+                _postfetch(
+                    mapper,
+                    uowtransaction,
+                    table,
+                    state,
+                    state_dict,
+                    c,
+                    c.context.compiled_parameters[0],
+                    value_params)
+                rows += c.rowcount
         else:
+            multiparams = [rec[2] for rec in records]
             c = cached_connections[connection].\
-                execute(statement, params)
-
-        _postfetch(
-            mapper,
-            uowtransaction,
-            table,
-            state,
-            state_dict,
-            c,
-            c.context.compiled_parameters[0],
-            value_params)
-        rows += c.rowcount
-
-    if connection.dialect.supports_sane_rowcount:
-        if rows != len(update):
-            raise orm_exc.StaleDataError(
-                "UPDATE statement on table '%s' expected to "
-                "update %d row(s); %d were matched." %
-                (table.description, len(update), rows))
-
-    elif needs_version_id:
-        util.warn("Dialect %s does not support updated rowcount "
-                  "- versioning cannot be verified." %
-                  c.dialect.dialect_description,
-                  stacklevel=12)
+                execute(statement, multiparams)
+
+            rows += c.rowcount
+            for state, state_dict, params, mapper, \
+                    connection, value_params in records:
+                _postfetch(
+                    mapper,
+                    uowtransaction,
+                    table,
+                    state,
+                    state_dict,
+                    c,
+                    c.context.compiled_parameters[0],
+                    value_params)
+
+        if connection.dialect.supports_sane_rowcount:
+            if rows != len(records):
+                raise orm_exc.StaleDataError(
+                    "UPDATE statement on table '%s' expected to "
+                    "update %d row(s); %d were matched." %
+                    (table.description, len(records), rows))
+
+        elif needs_version_id:
+            util.warn("Dialect %s does not support updated rowcount "
+                      "- versioning cannot be verified." %
+                      c.dialect.dialect_description,
+                      stacklevel=12)
 
 
 def _emit_insert_statements(base_mapper, uowtransaction,
@@ -833,15 +857,12 @@ def _connections_for_states(base_mapper, uowtransaction, states):
         connection_callable = \
             uowtransaction.session.connection_callable
     else:
-        connection = None
+        connection = uowtransaction.transaction.connection(base_mapper)
         connection_callable = None
 
     for state in _sort_states(states):
         if connection_callable:
             connection = connection_callable(base_mapper, state.obj())
-        elif not connection:
-            connection = uowtransaction.transaction.connection(
-                base_mapper)
 
         mapper = _state_mapper(state)
 
index 6eb7632130e1b6c18cd5bf7e83af991ba59db753..a54097b03e09d71571cd1c26019981c786bcb699 100644 (file)
@@ -1126,11 +1126,12 @@ class OneToManyTest(_fixtures.FixtureTest):
 
             ("UPDATE addresses SET user_id=:user_id "
              "WHERE addresses.id = :addresses_id",
-             {'user_id': None, 'addresses_id': a1.id}),
+             [
+                {'user_id': None, 'addresses_id': a1.id},
+                {'user_id': u1.id, 'addresses_id': a3.id}
+            ]),
 
-            ("UPDATE addresses SET user_id=:user_id "
-             "WHERE addresses.id = :addresses_id",
-             {'user_id': u1.id, 'addresses_id': a3.id})])
+            ])
 
     def test_child_move(self):
         """Moving a child from one parent to another, with a delete.
index 9c9296786efe50020da02c99ff101c7e8d52f5c7..c643e6a870b4a1d4ffe9c1ecb16ef0cb2cea2b02 100644 (file)
@@ -131,12 +131,10 @@ class RudimentaryFlushTest(UOWTest):
             CompiledSQL(
                 "UPDATE addresses SET user_id=:user_id WHERE "
                 "addresses.id = :addresses_id",
-                lambda ctx: [{'addresses_id': a1.id, 'user_id': None}]
-            ),
-            CompiledSQL(
-                "UPDATE addresses SET user_id=:user_id WHERE "
-                "addresses.id = :addresses_id",
-                lambda ctx: [{'addresses_id': a2.id, 'user_id': None}]
+                lambda ctx: [
+                    {'addresses_id': a1.id, 'user_id': None},
+                    {'addresses_id': a2.id, 'user_id': None}
+                ]
             ),
             CompiledSQL(
                 "DELETE FROM users WHERE users.id = :id",
@@ -240,12 +238,10 @@ class RudimentaryFlushTest(UOWTest):
             CompiledSQL(
                 "UPDATE addresses SET user_id=:user_id WHERE "
                 "addresses.id = :addresses_id",
-                lambda ctx: [{'addresses_id': a1.id, 'user_id': None}]
-            ),
-            CompiledSQL(
-                "UPDATE addresses SET user_id=:user_id WHERE "
-                "addresses.id = :addresses_id",
-                lambda ctx: [{'addresses_id': a2.id, 'user_id': None}]
+                lambda ctx: [
+                    {'addresses_id': a1.id, 'user_id': None},
+                    {'addresses_id': a2.id, 'user_id': None}
+                ]
             ),
             CompiledSQL(
                 "DELETE FROM users WHERE users.id = :id",
@@ -732,12 +728,11 @@ class SingleCycleTest(UOWTest):
             testing.db, sess.flush, AllOf(
                 CompiledSQL(
                     "UPDATE nodes SET parent_id=:parent_id "
-                    "WHERE nodes.id = :nodes_id", lambda ctx: {
-                        'nodes_id': n3.id, 'parent_id': None}),
-                CompiledSQL(
-                    "UPDATE nodes SET parent_id=:parent_id "
-                    "WHERE nodes.id = :nodes_id", lambda ctx: {
-                        'nodes_id': n2.id, 'parent_id': None}),
+                    "WHERE nodes.id = :nodes_id", lambda ctx: [
+                        {'nodes_id': n3.id, 'parent_id': None},
+                        {'nodes_id': n2.id, 'parent_id': None}
+                    ]
+                    )
             ),
             CompiledSQL(
                 "DELETE FROM nodes WHERE nodes.id = :id", lambda ctx: {