:ticket:`4257`
+.. _change_4196:
+
+Horizontal Sharding extension supports bulk update and delete methods
+---------------------------------------------------------------------
+
+The :class:`.ShardedQuery` extension object supports the :meth:`.Query.update`
+and :meth:`.Query.delete` bulk update/delete methods. The ``query_chooser``
+callable is consulted when they are called in order to run the update/delete
+across multiple shards based on given criteria.
+
+
+:ticket:`4196`
+
Key Behavioral Changes - ORM
=============================
--- /dev/null
+.. change::
+ :tags: feature, ext
+ :tickets: 4196
+
+ Added support for bulk :meth:`.Query.update` and :meth:`.Query.delete`
+ to the :class:`.ShardedQuery` class within the horiziontal sharding
+ extension. This also adds an additional expansion hook to the
+ bulk update/delete methods :meth:`.Query._execute_crud`.
+
+ .. seealso::
+
+ :ref:`change_4196`
# were done, this is where it would happen
return iter(partial)
+ def _execute_crud(self, stmt, mapper):
+ def exec_for_shard(shard_id):
+ conn = self._connection_from_session(
+ mapper=mapper,
+ shard_id=shard_id,
+ clause=stmt,
+ close_with_result=True)
+ result = conn.execute(stmt, self._params)
+ return result
+
+ if self._shard_id is not None:
+ return exec_for_shard(self._shard_id)
+ else:
+ rowcount = 0
+ results = []
+ for shard_id in self.query_chooser(self):
+ result = exec_for_shard(shard_id)
+ rowcount += result.rowcount
+ results.append(result)
+
+ return ShardedResult(results, rowcount)
+
def _identity_lookup(
self, mapper, primary_key_identity, identity_token=None,
lazy_loaded_from=None, **kw):
primary_key_identity, _db_load_fn, identity_token=identity_token)
+class ShardedResult(object):
+ """A value object that represents multiple :class:`.ResultProxy` objects.
+
+ This is used by the :meth:`.ShardedQuery._execute_crud` hook to return
+ an object that takes the place of the single :class:`.ResultProxy`.
+
+ Attribute include ``result_proxies``, which is a sequence of the
+ actual :class:`.ResultProxy` objects, as well as ``aggregate_rowcount``
+ or ``rowcount``, which is the sum of all the individual rowcount values.
+
+ .. versionadded:: 1.3
+ """
+
+ __slots__ = ('result_proxies', 'aggregate_rowcount',)
+
+ def __init__(self, result_proxies, aggregate_rowcount):
+ self.result_proxies = result_proxies
+ self.aggregate_rowcount = aggregate_rowcount
+
+ @property
+ def rowcount(self):
+ return self.aggregate_rowcount
+
class ShardedSession(Session):
def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None,
query_cls=ShardedQuery, **kwargs):
self._do_post()
def _execute_stmt(self, stmt):
- self.result = self.query.session.execute(
- stmt, params=self.query._params,
- mapper=self.mapper)
+ self.result = self.query._execute_crud(stmt, self.mapper)
self.rowcount = self.result.rowcount
@util.dependencies("sqlalchemy.orm.query")
result = conn.execute(querycontext.statement, self._params)
return loading.instances(querycontext.query, result, querycontext)
+ def _execute_crud(self, stmt, mapper):
+ conn = self._connection_from_session(
+ mapper=mapper, clause=stmt, close_with_result=True)
+
+ return conn.execute(stmt, self._params)
+
def _get_bind_args(self, querycontext, fn, **kw):
return fn(
mapper=self._bind_mapper(),
t = get_tokyo(sess2)
eq_(t.city, tokyo.city)
+ def test_bulk_update(self):
+ sess = self._fixture_data()
+
+ eq_(
+ set(row.temperature for row in sess.query(Report.temperature)),
+ {80.0, 75.0, 85.0}
+ )
+
+ temps = sess.query(Report).all()
+ eq_(
+ set(t.temperature for t in temps),
+ {80.0, 75.0, 85.0}
+ )
+
+ sess.query(Report).filter(
+ Report.temperature >= 80).update(
+ {"temperature": Report.temperature + 6})
+
+ eq_(
+ set(row.temperature for row in sess.query(Report.temperature)),
+ {86.0, 75.0, 91.0}
+ )
+
+ # test synchronize session as well
+ eq_(
+ set(t.temperature for t in temps),
+ {86.0, 75.0, 91.0}
+ )
+
+ def test_bulk_delete(self):
+ sess = self._fixture_data()
+
+ temps = sess.query(Report).all()
+ eq_(
+ set(t.temperature for t in temps),
+ {80.0, 75.0, 85.0}
+ )
+
+ sess.query(Report).filter(
+ Report.temperature >= 80).delete()
+
+ eq_(
+ set(row.temperature for row in sess.query(Report.temperature)),
+ {75.0}
+ )
+
+ # test synchronize session as well
+ for t in temps:
+ assert inspect(t).deleted is (t.temperature >= 80)
+
+
from sqlalchemy.testing import provision
# Do an update using unordered dict and check that the parameters used
# are ordered in table order
- with mock.patch.object(session, "execute") as exec_:
- session.query(User).filter(User.id == 15).update(
+ q = session.query(User)
+ with mock.patch.object(q, "_execute_crud") as exec_:
+ q.filter(User.id == 15).update(
{'name': 'foob', 'id': 123})
# Confirm that parameters are a dict instead of tuple or list
params_type = type(exec_.mock_calls[0][1][0].parameters)
session = Session()
# Do update using a tuple and check that order is preserved
- with mock.patch.object(session, "execute") as exec_:
- session.query(User).filter(User.id == 15).update(
+ q = session.query(User)
+ with mock.patch.object(q, "_execute_crud") as exec_:
+ q.filter(User.id == 15).update(
(('id', 123), ('name', 'foob')),
update_args={"preserve_parameter_order": True})
cols = [c.key
# Now invert the order and use a list instead, and check that order is
# also preserved
- with mock.patch.object(session, "execute") as exec_:
- session.query(User).filter(User.id == 15).update(
+ q = session.query(User)
+ with mock.patch.object(q, "_execute_crud") as exec_:
+ q.filter(User.id == 15).update(
[('name', 'foob'), ('id', 123)],
update_args={"preserve_parameter_order": True})
cols = [c.key
Data = self.classes.Data
session = testing.mock.Mock(wraps=Session())
update_args = {"mysql_limit": 1}
- query.Query(Data, session).update({Data.cnt: Data.cnt + 1},
- update_args=update_args)
- eq_(session.execute.call_count, 1)
- args, kwargs = session.execute.call_args
- eq_(len(args), 1)
+
+ q = session.query(Data)
+ with testing.mock.patch.object(q, '_execute_crud') as exec_:
+ q.update({Data.cnt: Data.cnt + 1},
+ update_args=update_args)
+ eq_(exec_.call_count, 1)
+ args, kwargs = exec_.mock_calls[0][1:3]
+ eq_(len(args), 2)
update_stmt = args[0]
eq_(update_stmt.dialect_kwargs, update_args)