From: Mike Bayer Date: Wed, 7 Feb 2018 00:30:55 +0000 (-0500) Subject: Route bulk update/delete exec through new Query._execute_crud method X-Git-Tag: rel_1_3_0b1~51^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3081269e6f1fc51d8d5cfc5120dd10ee2872e871;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Route bulk update/delete exec through new Query._execute_crud method Added support for bulk :meth:`.Query.update` and :meth:`.Query.delete` to the :class:`.ShardedQuery` class within the horiziontal sharding extension. This also adds an additional expansion hook to the bulk update/delete methods :meth:`.Query._execute_crud`. Fixes: #4196 Change-Id: I65f56458176497a8cbdd368f41b879881f06348b --- diff --git a/doc/build/changelog/migration_13.rst b/doc/build/changelog/migration_13.rst index 5a8e3ce05b..bc134c4373 100644 --- a/doc/build/changelog/migration_13.rst +++ b/doc/build/changelog/migration_13.rst @@ -39,6 +39,19 @@ along with that object's full lifecycle in memory:: :ticket:`4257` +.. _change_4196: + +Horizontal Sharding extension supports bulk update and delete methods +--------------------------------------------------------------------- + +The :class:`.ShardedQuery` extension object supports the :meth:`.Query.update` +and :meth:`.Query.delete` bulk update/delete methods. The ``query_chooser`` +callable is consulted when they are called in order to run the update/delete +across multiple shards based on given criteria. + + +:ticket:`4196` + Key Behavioral Changes - ORM ============================= diff --git a/doc/build/changelog/unreleased_13/4196.rst b/doc/build/changelog/unreleased_13/4196.rst new file mode 100644 index 0000000000..c23002bafb --- /dev/null +++ b/doc/build/changelog/unreleased_13/4196.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: feature, ext + :tickets: 4196 + + Added support for bulk :meth:`.Query.update` and :meth:`.Query.delete` + to the :class:`.ShardedQuery` class within the horiziontal sharding + extension. This also adds an additional expansion hook to the + bulk update/delete methods :meth:`.Query._execute_crud`. + + .. seealso:: + + :ref:`change_4196` diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 6ef4c56126..425d289637 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -64,6 +64,28 @@ class ShardedQuery(Query): # were done, this is where it would happen return iter(partial) + def _execute_crud(self, stmt, mapper): + def exec_for_shard(shard_id): + conn = self._connection_from_session( + mapper=mapper, + shard_id=shard_id, + clause=stmt, + close_with_result=True) + result = conn.execute(stmt, self._params) + return result + + if self._shard_id is not None: + return exec_for_shard(self._shard_id) + else: + rowcount = 0 + results = [] + for shard_id in self.query_chooser(self): + result = exec_for_shard(shard_id) + rowcount += result.rowcount + results.append(result) + + return ShardedResult(results, rowcount) + def _identity_lookup( self, mapper, primary_key_identity, identity_token=None, lazy_loaded_from=None, **kw): @@ -123,6 +145,29 @@ class ShardedQuery(Query): primary_key_identity, _db_load_fn, identity_token=identity_token) +class ShardedResult(object): + """A value object that represents multiple :class:`.ResultProxy` objects. + + This is used by the :meth:`.ShardedQuery._execute_crud` hook to return + an object that takes the place of the single :class:`.ResultProxy`. + + Attribute include ``result_proxies``, which is a sequence of the + actual :class:`.ResultProxy` objects, as well as ``aggregate_rowcount`` + or ``rowcount``, which is the sum of all the individual rowcount values. + + .. versionadded:: 1.3 + """ + + __slots__ = ('result_proxies', 'aggregate_rowcount',) + + def __init__(self, result_proxies, aggregate_rowcount): + self.result_proxies = result_proxies + self.aggregate_rowcount = aggregate_rowcount + + @property + def rowcount(self): + return self.aggregate_rowcount + class ShardedSession(Session): def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, query_cls=ShardedQuery, **kwargs): diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 95e26d83c1..afa3b50b9a 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -1337,9 +1337,7 @@ class BulkUD(object): self._do_post() def _execute_stmt(self, stmt): - self.result = self.query.session.execute( - stmt, params=self.query._params, - mapper=self.mapper) + self.result = self.query._execute_crud(stmt, self.mapper) self.rowcount = self.result.rowcount @util.dependencies("sqlalchemy.orm.query") diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 7e7c93527b..e96996a395 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -3016,6 +3016,12 @@ class Query(object): result = conn.execute(querycontext.statement, self._params) return loading.instances(querycontext.query, result, querycontext) + def _execute_crud(self, stmt, mapper): + conn = self._connection_from_session( + mapper=mapper, clause=stmt, close_with_result=True) + + return conn.execute(stmt, self._params) + def _get_bind_args(self, querycontext, fn, **kw): return fn( mapper=self._bind_mapper(), diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index 2773be7d53..bd41594757 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -321,6 +321,57 @@ class ShardTest(object): t = get_tokyo(sess2) eq_(t.city, tokyo.city) + def test_bulk_update(self): + sess = self._fixture_data() + + eq_( + set(row.temperature for row in sess.query(Report.temperature)), + {80.0, 75.0, 85.0} + ) + + temps = sess.query(Report).all() + eq_( + set(t.temperature for t in temps), + {80.0, 75.0, 85.0} + ) + + sess.query(Report).filter( + Report.temperature >= 80).update( + {"temperature": Report.temperature + 6}) + + eq_( + set(row.temperature for row in sess.query(Report.temperature)), + {86.0, 75.0, 91.0} + ) + + # test synchronize session as well + eq_( + set(t.temperature for t in temps), + {86.0, 75.0, 91.0} + ) + + def test_bulk_delete(self): + sess = self._fixture_data() + + temps = sess.query(Report).all() + eq_( + set(t.temperature for t in temps), + {80.0, 75.0, 85.0} + ) + + sess.query(Report).filter( + Report.temperature >= 80).delete() + + eq_( + set(row.temperature for row in sess.query(Report.temperature)), + {75.0} + ) + + # test synchronize session as well + for t in temps: + assert inspect(t).deleted is (t.temperature >= 80) + + from sqlalchemy.testing import provision diff --git a/test/orm/test_update_delete.py b/test/orm/test_update_delete.py index d1ccbb2e1a..d1ea22dcc4 100644 --- a/test/orm/test_update_delete.py +++ b/test/orm/test_update_delete.py @@ -599,8 +599,9 @@ class UpdateDeleteTest(fixtures.MappedTest): # Do an update using unordered dict and check that the parameters used # are ordered in table order - with mock.patch.object(session, "execute") as exec_: - session.query(User).filter(User.id == 15).update( + q = session.query(User) + with mock.patch.object(q, "_execute_crud") as exec_: + q.filter(User.id == 15).update( {'name': 'foob', 'id': 123}) # Confirm that parameters are a dict instead of tuple or list params_type = type(exec_.mock_calls[0][1][0].parameters) @@ -611,8 +612,9 @@ class UpdateDeleteTest(fixtures.MappedTest): session = Session() # Do update using a tuple and check that order is preserved - with mock.patch.object(session, "execute") as exec_: - session.query(User).filter(User.id == 15).update( + q = session.query(User) + with mock.patch.object(q, "_execute_crud") as exec_: + q.filter(User.id == 15).update( (('id', 123), ('name', 'foob')), update_args={"preserve_parameter_order": True}) cols = [c.key @@ -621,8 +623,9 @@ class UpdateDeleteTest(fixtures.MappedTest): # Now invert the order and use a list instead, and check that order is # also preserved - with mock.patch.object(session, "execute") as exec_: - session.query(User).filter(User.id == 15).update( + q = session.query(User) + with mock.patch.object(q, "_execute_crud") as exec_: + q.filter(User.id == 15).update( [('name', 'foob'), ('id', 123)], update_args={"preserve_parameter_order": True}) cols = [c.key @@ -951,11 +954,14 @@ class ExpressionUpdateTest(fixtures.MappedTest): Data = self.classes.Data session = testing.mock.Mock(wraps=Session()) update_args = {"mysql_limit": 1} - query.Query(Data, session).update({Data.cnt: Data.cnt + 1}, - update_args=update_args) - eq_(session.execute.call_count, 1) - args, kwargs = session.execute.call_args - eq_(len(args), 1) + + q = session.query(Data) + with testing.mock.patch.object(q, '_execute_crud') as exec_: + q.update({Data.cnt: Data.cnt + 1}, + update_args=update_args) + eq_(exec_.call_count, 1) + args, kwargs = exec_.mock_calls[0][1:3] + eq_(len(args), 2) update_stmt = args[0] eq_(update_stmt.dialect_kwargs, update_args)