From: Ants Aasma Date: Thu, 29 May 2008 02:12:17 +0000 (+0000) Subject: Add delete and update methods to query X-Git-Tag: rel_0_5beta1~15 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=87718fa82b498425cc754f1e334c0e1219470c72;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add delete and update methods to query --- diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 429fc40c1d..c2e1550af4 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -20,12 +20,12 @@ iterable result sets. from itertools import chain -from sqlalchemy import sql, util, log +from sqlalchemy import sql, util, log, schema from sqlalchemy import exc as sa_exc from sqlalchemy.orm import exc as orm_exc from sqlalchemy.sql import util as sql_util from sqlalchemy.sql import expression, visitors, operators -from sqlalchemy.orm import attributes, interfaces, mapper, object_mapper +from sqlalchemy.orm import attributes, interfaces, mapper, object_mapper, evaluator from sqlalchemy.orm.util import _state_mapper, _is_mapped_class, \ _is_aliased_class, _entity_descriptor, _entity_info, _class_to_mapper, \ _orm_columns, AliasedClass, _orm_selectable, join as orm_join, ORMAdapter @@ -1239,7 +1239,97 @@ class Query(object): if self._autoflush and not self._populate_existing: self.session._autoflush() return self.session.scalar(s, params=self._params, mapper=self._mapper_zero()) - + + def delete(self, synchronize_session='evaluate'): + """EXPERIMENTAL""" + #TODO: lots of duplication and ifs - probably needs to be refactored to strategies + context = self._compile_context() + if len(context.statement.froms) != 1 or not isinstance(context.statement.froms[0], schema.Table): + raise sa_exc.ArgumentError("Only deletion via a single table query is currently supported") + primary_table = context.statement.froms[0] + + session = self.session + + if synchronize_session == 'evaluate': + try: + evaluator_compiler = evaluator.EvaluatorCompiler() + eval_condition = evaluator_compiler.process(self.whereclause) + except evaluator.UnevaluatableError: + synchronize_session = 'fetch' + + delete_stmt = sql.delete(primary_table, context.whereclause) + + if synchronize_session == 'fetch': + #TODO: use RETURNING when available + select_stmt = context.statement.with_only_columns(primary_table.primary_key) + matched_rows = session.execute(select_stmt).fetchall() + + session.execute(delete_stmt) + + if synchronize_session == 'evaluate': + target_cls = self._mapper_zero().class_ + + #TODO: detect when the where clause is a trivial primary key match + objs_to_expunge = [obj for (cls, pk, entity_name),obj in session.identity_map.iteritems() + if issubclass(cls, target_cls) and eval_condition(obj)] + for obj in objs_to_expunge: + session.expunge(obj) + elif synchronize_session == 'fetch': + target_mapper = self._mapper_zero() + for primary_key in matched_rows: + identity_key = target_mapper.identity_key_from_primary_key(list(primary_key)) + if identity_key in session.identity_map: + session.expunge(session.identity_map[identity_key]) + + def update(self, values, synchronize_session='evaluate'): + """EXPERIMENTAL""" + + #TODO: value keys need to be mapped to corresponding sql cols and instr.attr.s to string keys + #TODO: updates of manytoone relations need to be converted to fk assignments + + context = self._compile_context() + if len(context.statement.froms) != 1 or not isinstance(context.statement.froms[0], schema.Table): + raise sa_exc.ArgumentError("Only update via a single table query is currently supported") + primary_table = context.statement.froms[0] + + session = self.session + + if synchronize_session == 'evaluate': + try: + evaluator_compiler = evaluator.EvaluatorCompiler() + eval_condition = evaluator_compiler.process(self.whereclause) + + value_evaluators = {} + for key,value in values.items(): + value_evaluators[key] = evaluator_compiler.process(expression._literal_as_binds(value)) + except evaluator.UnevaluatableError: + synchronize_session = 'expire' + + update_stmt = sql.update(primary_table, context.whereclause, values) + + if synchronize_session == 'expire': + select_stmt = context.statement.with_only_columns(primary_table.primary_key) + matched_rows = session.execute(select_stmt).fetchall() + + session.execute(update_stmt) + + if synchronize_session == 'evaluate': + target_cls = self._mapper_zero().class_ + + for (cls, pk, entity_name),obj in session.identity_map.iteritems(): + if issubclass(cls, target_cls) and eval_condition(obj): + for key,eval_value in value_evaluators.items(): + obj.__dict__[key] = eval_value(obj) + + elif synchronize_session == 'expire': + target_mapper = self._mapper_zero() + + for primary_key in matched_rows: + identity_key = target_mapper.identity_key_from_primary_key(list(primary_key)) + if identity_key in session.identity_map: + session.expire(session.identity_map[identity_key], values.keys()) + + def _compile_context(self, labels=True): context = QueryContext(self) diff --git a/test/orm/query.py b/test/orm/query.py index e23b0aadd5..53927d65dc 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -2123,6 +2123,93 @@ class TestOverlyEagerEquivalentCols(_base.MappedTest): b1 ) +class UpdateTest(_base.MappedTest): + def define_tables(self, metadata): + Table('users', metadata, + Column('id', Integer, primary_key=True), + Column('name', String), + Column('age', Integer)) + + def setup_classes(self): + class User(_base.ComparableEntity): + pass + + @testing.resolve_artifact_names + def insert_data(self): + users.insert().execute([ + dict(id=1, name='john', age=25), + dict(id=2, name='jack', age=47), + dict(id=3, name='jill', age=29), + dict(id=4, name='jane', age=37), + ]) + + @testing.resolve_artifact_names + def setup_mappers(self): + mapper(User, users) + + @testing.resolve_artifact_names + def test_delete(self): + sess = create_session(bind=testing.db, autocommit=False) + + john,jack,jill,jane = sess.query(User).order_by(User.id).all() + sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete() + + assert john not in sess and jill not in sess + + eq_(sess.query(User).order_by(User.id).all(), [jack,jane]) + + @testing.resolve_artifact_names + def test_delete_without_session_sync(self): + sess = create_session(bind=testing.db, autocommit=False) + + john,jack,jill,jane = sess.query(User).order_by(User.id).all() + sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete(synchronize_session=False) + + assert john in sess and jill in sess + + eq_(sess.query(User).order_by(User.id).all(), [jack,jane]) + + @testing.resolve_artifact_names + def test_delete_with_fetch_strategy(self): + sess = create_session(bind=testing.db, autocommit=False) + + john,jack,jill,jane = sess.query(User).order_by(User.id).all() + sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete(synchronize_session='fetch') + + assert john not in sess and jill not in sess + + eq_(sess.query(User).order_by(User.id).all(), [jack,jane]) + + @testing.resolve_artifact_names + def test_delete_fallback(self): + sess = create_session(bind=testing.db, autocommit=False) + + john,jack,jill,jane = sess.query(User).order_by(User.id).all() + sess.query(User).filter(User.name == select([func.max(User.name)])).delete() + + assert john not in sess + + eq_(sess.query(User).order_by(User.id).all(), [jack,jill,jane]) + + @testing.resolve_artifact_names + def test_update(self): + sess = create_session(bind=testing.db, autocommit=False) + + john,jack,jill,jane = sess.query(User).order_by(User.id).all() + sess.query(User).filter(User.age > 29).update({'age': User.age - 10}) + + eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27]) + eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27])) + + @testing.resolve_artifact_names + def test_update_with_expire_strategy(self): + sess = create_session(bind=testing.db, autocommit=False) + + john,jack,jill,jane = sess.query(User).order_by(User.id).all() + sess.query(User).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='expire') + + eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27]) + eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27])) if __name__ == '__main__': testenv.main()