]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Route bulk update/delete exec through new Query._execute_crud method
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 7 Feb 2018 00:30:55 +0000 (19:30 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 4 Oct 2018 17:59:41 +0000 (13:59 -0400)
Added support for bulk :meth:`.Query.update` and :meth:`.Query.delete`
to the :class:`.ShardedQuery` class within the horiziontal sharding
extension.  This also adds an additional expansion hook to the
bulk update/delete methods :meth:`.Query._execute_crud`.

Fixes: #4196
Change-Id: I65f56458176497a8cbdd368f41b879881f06348b

doc/build/changelog/migration_13.rst
doc/build/changelog/unreleased_13/4196.rst [new file with mode: 0644]
lib/sqlalchemy/ext/horizontal_shard.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/query.py
test/ext/test_horizontal_shard.py
test/orm/test_update_delete.py

index 5a8e3ce05bc75e34d3bd4189a87be3b2d3e0d430..bc134c4373c3b2d898bced80bf83c4e93eb5a994 100644 (file)
@@ -39,6 +39,19 @@ along with that object's full lifecycle in memory::
 
 :ticket:`4257`
 
+.. _change_4196:
+
+Horizontal Sharding extension supports bulk update and delete methods
+---------------------------------------------------------------------
+
+The :class:`.ShardedQuery` extension object supports the :meth:`.Query.update`
+and :meth:`.Query.delete` bulk update/delete methods.    The ``query_chooser``
+callable is consulted when they are called in order to run the update/delete
+across multiple shards based on given criteria.
+
+
+:ticket:`4196`
+
 Key Behavioral Changes - ORM
 =============================
 
diff --git a/doc/build/changelog/unreleased_13/4196.rst b/doc/build/changelog/unreleased_13/4196.rst
new file mode 100644 (file)
index 0000000..c23002b
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: feature, ext
+    :tickets: 4196
+
+    Added support for bulk :meth:`.Query.update` and :meth:`.Query.delete`
+    to the :class:`.ShardedQuery` class within the horiziontal sharding
+    extension.  This also adds an additional expansion hook to the
+    bulk update/delete methods :meth:`.Query._execute_crud`.
+
+    .. seealso::
+
+        :ref:`change_4196`
index 6ef4c56126f0c60def015957366a711ac452c6b4..425d289637ad66adaf8238bb7a9ab184fa19e216 100644 (file)
@@ -64,6 +64,28 @@ class ShardedQuery(Query):
             # were done, this is where it would happen
             return iter(partial)
 
+    def _execute_crud(self, stmt, mapper):
+        def exec_for_shard(shard_id):
+            conn = self._connection_from_session(
+                mapper=mapper,
+                shard_id=shard_id,
+                clause=stmt,
+                close_with_result=True)
+            result = conn.execute(stmt, self._params)
+            return result
+
+        if self._shard_id is not None:
+            return exec_for_shard(self._shard_id)
+        else:
+            rowcount = 0
+            results = []
+            for shard_id in self.query_chooser(self):
+                result = exec_for_shard(shard_id)
+                rowcount += result.rowcount
+                results.append(result)
+
+            return ShardedResult(results, rowcount)
+
     def _identity_lookup(
             self, mapper, primary_key_identity, identity_token=None,
             lazy_loaded_from=None, **kw):
@@ -123,6 +145,29 @@ class ShardedQuery(Query):
             primary_key_identity, _db_load_fn, identity_token=identity_token)
 
 
+class ShardedResult(object):
+    """A value object that represents multiple :class:`.ResultProxy` objects.
+
+    This is used by the :meth:`.ShardedQuery._execute_crud` hook to return
+    an object that takes the place of the single :class:`.ResultProxy`.
+
+    Attribute include ``result_proxies``, which is a sequence of the
+    actual :class:`.ResultProxy` objects, as well as ``aggregate_rowcount``
+    or ``rowcount``, which is the sum of all the individual rowcount values.
+
+    .. versionadded::  1.3
+    """
+
+    __slots__ = ('result_proxies', 'aggregate_rowcount',)
+
+    def __init__(self, result_proxies, aggregate_rowcount):
+        self.result_proxies = result_proxies
+        self.aggregate_rowcount = aggregate_rowcount
+
+    @property
+    def rowcount(self):
+        return self.aggregate_rowcount
+
 class ShardedSession(Session):
     def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None,
                  query_cls=ShardedQuery, **kwargs):
index 95e26d83c174e506a628b5f36cd9dadbb8e95f53..afa3b50b9ad3deacf563987f0a418ef1b636e510 100644 (file)
@@ -1337,9 +1337,7 @@ class BulkUD(object):
         self._do_post()
 
     def _execute_stmt(self, stmt):
-        self.result = self.query.session.execute(
-            stmt, params=self.query._params,
-            mapper=self.mapper)
+        self.result = self.query._execute_crud(stmt, self.mapper)
         self.rowcount = self.result.rowcount
 
     @util.dependencies("sqlalchemy.orm.query")
index 7e7c93527bba9d2d99de7d499ed3ce8ae26bffda..e96996a3951f662f6107c53ed013b352cad774b2 100644 (file)
@@ -3016,6 +3016,12 @@ class Query(object):
         result = conn.execute(querycontext.statement, self._params)
         return loading.instances(querycontext.query, result, querycontext)
 
+    def _execute_crud(self, stmt, mapper):
+        conn = self._connection_from_session(
+            mapper=mapper, clause=stmt, close_with_result=True)
+
+        return conn.execute(stmt, self._params)
+
     def _get_bind_args(self, querycontext, fn, **kw):
         return fn(
             mapper=self._bind_mapper(),
index 2773be7d533a8abaf80a823ad6a26890293b0112..bd41594757dd130f7d89fe9b3d4e1460e8bab5f5 100644 (file)
@@ -321,6 +321,57 @@ class ShardTest(object):
         t = get_tokyo(sess2)
         eq_(t.city, tokyo.city)
 
+    def test_bulk_update(self):
+        sess = self._fixture_data()
+
+        eq_(
+            set(row.temperature for row in sess.query(Report.temperature)),
+            {80.0, 75.0, 85.0}
+        )
+
+        temps = sess.query(Report).all()
+        eq_(
+            set(t.temperature for t in temps),
+            {80.0, 75.0, 85.0}
+        )
+
+        sess.query(Report).filter(
+            Report.temperature >= 80).update(
+                {"temperature": Report.temperature + 6})
+
+        eq_(
+            set(row.temperature for row in sess.query(Report.temperature)),
+            {86.0, 75.0, 91.0}
+        )
+
+        # test synchronize session as well
+        eq_(
+            set(t.temperature for t in temps),
+            {86.0, 75.0, 91.0}
+        )
+
+    def test_bulk_delete(self):
+        sess = self._fixture_data()
+
+        temps = sess.query(Report).all()
+        eq_(
+            set(t.temperature for t in temps),
+            {80.0, 75.0, 85.0}
+        )
+
+        sess.query(Report).filter(
+            Report.temperature >= 80).delete()
+
+        eq_(
+            set(row.temperature for row in sess.query(Report.temperature)),
+            {75.0}
+        )
+
+        # test synchronize session as well
+        for t in temps:
+            assert inspect(t).deleted is (t.temperature >= 80)
+
+
 
 from sqlalchemy.testing import provision
 
index d1ccbb2e1a5af7b0edebaa667ad8d7ae8ff6f624..d1ea22dcc46e9ce3d70d6f959da4dbe4ac403474 100644 (file)
@@ -599,8 +599,9 @@ class UpdateDeleteTest(fixtures.MappedTest):
 
         # Do an update using unordered dict and check that the parameters used
         # are ordered in table order
-        with mock.patch.object(session, "execute") as exec_:
-            session.query(User).filter(User.id == 15).update(
+        q = session.query(User)
+        with mock.patch.object(q, "_execute_crud") as exec_:
+            q.filter(User.id == 15).update(
                 {'name': 'foob', 'id': 123})
             # Confirm that parameters are a dict instead of tuple or list
             params_type = type(exec_.mock_calls[0][1][0].parameters)
@@ -611,8 +612,9 @@ class UpdateDeleteTest(fixtures.MappedTest):
         session = Session()
 
         # Do update using a tuple and check that order is preserved
-        with mock.patch.object(session, "execute") as exec_:
-            session.query(User).filter(User.id == 15).update(
+        q = session.query(User)
+        with mock.patch.object(q, "_execute_crud") as exec_:
+            q.filter(User.id == 15).update(
                 (('id', 123), ('name', 'foob')),
                 update_args={"preserve_parameter_order": True})
             cols = [c.key
@@ -621,8 +623,9 @@ class UpdateDeleteTest(fixtures.MappedTest):
 
         # Now invert the order and use a list instead, and check that order is
         # also preserved
-        with mock.patch.object(session, "execute") as exec_:
-            session.query(User).filter(User.id == 15).update(
+        q = session.query(User)
+        with mock.patch.object(q, "_execute_crud") as exec_:
+            q.filter(User.id == 15).update(
                 [('name', 'foob'), ('id', 123)],
                 update_args={"preserve_parameter_order": True})
             cols = [c.key
@@ -951,11 +954,14 @@ class ExpressionUpdateTest(fixtures.MappedTest):
         Data = self.classes.Data
         session = testing.mock.Mock(wraps=Session())
         update_args = {"mysql_limit": 1}
-        query.Query(Data, session).update({Data.cnt: Data.cnt + 1},
-                                          update_args=update_args)
-        eq_(session.execute.call_count, 1)
-        args, kwargs = session.execute.call_args
-        eq_(len(args), 1)
+
+        q = session.query(Data)
+        with testing.mock.patch.object(q, '_execute_crud') as exec_:
+            q.update({Data.cnt: Data.cnt + 1},
+                     update_args=update_args)
+        eq_(exec_.call_count, 1)
+        args, kwargs = exec_.mock_calls[0][1:3]
+        eq_(len(args), 2)
         update_stmt = args[0]
         eq_(update_stmt.dialect_kwargs, update_args)