]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Make Query.update and Query.delete return the amount of rows matched
authorAnts Aasma <ants.aasma@gmail.com>
Tue, 2 Sep 2008 20:02:02 +0000 (20:02 +0000)
committerAnts Aasma <ants.aasma@gmail.com>
Tue, 2 Sep 2008 20:02:02 +0000 (20:02 +0000)
lib/sqlalchemy/orm/query.py
test/orm/query.py

index 6c20a1fe3cee189eadedf7f8d0b55b93ab9996e2..0abf477ad636f4a6b8b28831c63747820b5b3785 100644 (file)
@@ -1254,6 +1254,8 @@ class Query(object):
           The expression evaluator currently doesn't account for differing string
           collations between the database and Python.
 
+        Returns the number of rows deleted, excluding any cascades.
+
         Warning - this currently doesn't account for any foreign key/relation cascades.
         """
         #TODO: lots of duplication and ifs - probably needs to be refactored to strategies
@@ -1285,7 +1287,7 @@ class Query(object):
 
         if self._autoflush:
             session._autoflush()
-        session.execute(delete_stmt)
+        result = session.execute(delete_stmt)
 
         if synchronize_session == 'evaluate':
             target_cls = self._mapper_zero().class_
@@ -1302,6 +1304,8 @@ class Query(object):
                 if identity_key in session.identity_map:
                     session._remove_newly_deleted(attributes.instance_state(session.identity_map[identity_key]))
 
+        return result.rowcount
+
     def update(self, values, synchronize_session='expire'):
         """Perform a bulk update query.
 
@@ -1327,6 +1331,8 @@ class Query(object):
           The expression evaluator currently doesn't account for differing string
           collations between the database and Python.
 
+        Returns the number of rows matched by the update.
+
         Warning - this currently doesn't account for any foreign key/relation cascades.
         """
 
@@ -1363,7 +1369,7 @@ class Query(object):
 
         if self._autoflush:
             session._autoflush()
-        session.execute(update_stmt)
+        result = session.execute(update_stmt)
 
         if synchronize_session == 'evaluate':
             target_cls = self._mapper_zero().class_
@@ -1392,6 +1398,8 @@ class Query(object):
                 if identity_key in session.identity_map:
                     session.expire(session.identity_map[identity_key], values.keys())
 
+        return result.rowcount
+
 
     def _compile_context(self, labels=True):
         context = QueryContext(self)
index 8f0148e75cc8b78c2ab517d1fc7e445334aa9509..26903edf1d1265a9171f2e662e21af268474d801 100644 (file)
@@ -2355,5 +2355,22 @@ class UpdateDeleteTest(_base.MappedTest):
         eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27])
         eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27]))
 
+    @testing.resolve_artifact_names
+    def test_update_returns_rowcount(self):
+        sess = create_session(bind=testing.db, autocommit=False)
+
+        rowcount = sess.query(User).filter(User.age > 29).update({'age': User.age + 0})
+        self.assertEquals(rowcount, 2)
+
+        rowcount = sess.query(User).filter(User.age > 29).update({'age': User.age - 10})
+        self.assertEquals(rowcount, 2)
+
+    @testing.resolve_artifact_names
+    def test_delete_returns_rowcount(self):
+        sess = create_session(bind=testing.db, autocommit=False)
+
+        rowcount = sess.query(User).filter(User.age > 26).delete(synchronize_session=False)
+        self.assertEquals(rowcount, 3)
+
 if __name__ == '__main__':
     testenv.main()