From: Ants Aasma Date: Tue, 2 Sep 2008 20:02:02 +0000 (+0000) Subject: Make Query.update and Query.delete return the amount of rows matched X-Git-Tag: rel_0_5rc1~23 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=920281ab55b407c9674759fa885797e1a9fff908;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Make Query.update and Query.delete return the amount of rows matched --- diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 6c20a1fe3c..0abf477ad6 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1254,6 +1254,8 @@ class Query(object): The expression evaluator currently doesn't account for differing string collations between the database and Python. + Returns the number of rows deleted, excluding any cascades. + Warning - this currently doesn't account for any foreign key/relation cascades. """ #TODO: lots of duplication and ifs - probably needs to be refactored to strategies @@ -1285,7 +1287,7 @@ class Query(object): if self._autoflush: session._autoflush() - session.execute(delete_stmt) + result = session.execute(delete_stmt) if synchronize_session == 'evaluate': target_cls = self._mapper_zero().class_ @@ -1302,6 +1304,8 @@ class Query(object): if identity_key in session.identity_map: session._remove_newly_deleted(attributes.instance_state(session.identity_map[identity_key])) + return result.rowcount + def update(self, values, synchronize_session='expire'): """Perform a bulk update query. @@ -1327,6 +1331,8 @@ class Query(object): The expression evaluator currently doesn't account for differing string collations between the database and Python. + Returns the number of rows matched by the update. + Warning - this currently doesn't account for any foreign key/relation cascades. """ @@ -1363,7 +1369,7 @@ class Query(object): if self._autoflush: session._autoflush() - session.execute(update_stmt) + result = session.execute(update_stmt) if synchronize_session == 'evaluate': target_cls = self._mapper_zero().class_ @@ -1392,6 +1398,8 @@ class Query(object): if identity_key in session.identity_map: session.expire(session.identity_map[identity_key], values.keys()) + return result.rowcount + def _compile_context(self, labels=True): context = QueryContext(self) diff --git a/test/orm/query.py b/test/orm/query.py index 8f0148e75c..26903edf1d 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -2355,5 +2355,22 @@ class UpdateDeleteTest(_base.MappedTest): eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27]) eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27])) + @testing.resolve_artifact_names + def test_update_returns_rowcount(self): + sess = create_session(bind=testing.db, autocommit=False) + + rowcount = sess.query(User).filter(User.age > 29).update({'age': User.age + 0}) + self.assertEquals(rowcount, 2) + + rowcount = sess.query(User).filter(User.age > 29).update({'age': User.age - 10}) + self.assertEquals(rowcount, 2) + + @testing.resolve_artifact_names + def test_delete_returns_rowcount(self): + sess = create_session(bind=testing.db, autocommit=False) + + rowcount = sess.query(User).filter(User.age > 26).delete(synchronize_session=False) + self.assertEquals(rowcount, 3) + if __name__ == '__main__': testenv.main()