from itertools import chain
-from sqlalchemy import sql, util, log
+from sqlalchemy import sql, util, log, schema
from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.sql import util as sql_util
from sqlalchemy.sql import expression, visitors, operators
-from sqlalchemy.orm import attributes, interfaces, mapper, object_mapper
+from sqlalchemy.orm import attributes, interfaces, mapper, object_mapper, evaluator
from sqlalchemy.orm.util import _state_mapper, _is_mapped_class, \
_is_aliased_class, _entity_descriptor, _entity_info, _class_to_mapper, \
_orm_columns, AliasedClass, _orm_selectable, join as orm_join, ORMAdapter
if self._autoflush and not self._populate_existing:
self.session._autoflush()
return self.session.scalar(s, params=self._params, mapper=self._mapper_zero())
-
+
+ def delete(self, synchronize_session='evaluate'):
+ """EXPERIMENTAL"""
+ #TODO: lots of duplication and ifs - probably needs to be refactored to strategies
+ context = self._compile_context()
+ if len(context.statement.froms) != 1 or not isinstance(context.statement.froms[0], schema.Table):
+ raise sa_exc.ArgumentError("Only deletion via a single table query is currently supported")
+ primary_table = context.statement.froms[0]
+
+ session = self.session
+
+ if synchronize_session == 'evaluate':
+ try:
+ evaluator_compiler = evaluator.EvaluatorCompiler()
+ eval_condition = evaluator_compiler.process(self.whereclause)
+ except evaluator.UnevaluatableError:
+ synchronize_session = 'fetch'
+
+ delete_stmt = sql.delete(primary_table, context.whereclause)
+
+ if synchronize_session == 'fetch':
+ #TODO: use RETURNING when available
+ select_stmt = context.statement.with_only_columns(primary_table.primary_key)
+ matched_rows = session.execute(select_stmt).fetchall()
+
+ session.execute(delete_stmt)
+
+ if synchronize_session == 'evaluate':
+ target_cls = self._mapper_zero().class_
+
+ #TODO: detect when the where clause is a trivial primary key match
+ objs_to_expunge = [obj for (cls, pk, entity_name),obj in session.identity_map.iteritems()
+ if issubclass(cls, target_cls) and eval_condition(obj)]
+ for obj in objs_to_expunge:
+ session.expunge(obj)
+ elif synchronize_session == 'fetch':
+ target_mapper = self._mapper_zero()
+ for primary_key in matched_rows:
+ identity_key = target_mapper.identity_key_from_primary_key(list(primary_key))
+ if identity_key in session.identity_map:
+ session.expunge(session.identity_map[identity_key])
+
+ def update(self, values, synchronize_session='evaluate'):
+ """EXPERIMENTAL"""
+
+ #TODO: value keys need to be mapped to corresponding sql cols and instr.attr.s to string keys
+ #TODO: updates of manytoone relations need to be converted to fk assignments
+
+ context = self._compile_context()
+ if len(context.statement.froms) != 1 or not isinstance(context.statement.froms[0], schema.Table):
+ raise sa_exc.ArgumentError("Only update via a single table query is currently supported")
+ primary_table = context.statement.froms[0]
+
+ session = self.session
+
+ if synchronize_session == 'evaluate':
+ try:
+ evaluator_compiler = evaluator.EvaluatorCompiler()
+ eval_condition = evaluator_compiler.process(self.whereclause)
+
+ value_evaluators = {}
+ for key,value in values.items():
+ value_evaluators[key] = evaluator_compiler.process(expression._literal_as_binds(value))
+ except evaluator.UnevaluatableError:
+ synchronize_session = 'expire'
+
+ update_stmt = sql.update(primary_table, context.whereclause, values)
+
+ if synchronize_session == 'expire':
+ select_stmt = context.statement.with_only_columns(primary_table.primary_key)
+ matched_rows = session.execute(select_stmt).fetchall()
+
+ session.execute(update_stmt)
+
+ if synchronize_session == 'evaluate':
+ target_cls = self._mapper_zero().class_
+
+ for (cls, pk, entity_name),obj in session.identity_map.iteritems():
+ if issubclass(cls, target_cls) and eval_condition(obj):
+ for key,eval_value in value_evaluators.items():
+ obj.__dict__[key] = eval_value(obj)
+
+ elif synchronize_session == 'expire':
+ target_mapper = self._mapper_zero()
+
+ for primary_key in matched_rows:
+ identity_key = target_mapper.identity_key_from_primary_key(list(primary_key))
+ if identity_key in session.identity_map:
+ session.expire(session.identity_map[identity_key], values.keys())
+
+
def _compile_context(self, labels=True):
context = QueryContext(self)
b1
)
+class UpdateTest(_base.MappedTest):
+ def define_tables(self, metadata):
+ Table('users', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', String),
+ Column('age', Integer))
+
+ def setup_classes(self):
+ class User(_base.ComparableEntity):
+ pass
+
+ @testing.resolve_artifact_names
+ def insert_data(self):
+ users.insert().execute([
+ dict(id=1, name='john', age=25),
+ dict(id=2, name='jack', age=47),
+ dict(id=3, name='jill', age=29),
+ dict(id=4, name='jane', age=37),
+ ])
+
+ @testing.resolve_artifact_names
+ def setup_mappers(self):
+ mapper(User, users)
+
+ @testing.resolve_artifact_names
+ def test_delete(self):
+ sess = create_session(bind=testing.db, autocommit=False)
+
+ john,jack,jill,jane = sess.query(User).order_by(User.id).all()
+ sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete()
+
+ assert john not in sess and jill not in sess
+
+ eq_(sess.query(User).order_by(User.id).all(), [jack,jane])
+
+ @testing.resolve_artifact_names
+ def test_delete_without_session_sync(self):
+ sess = create_session(bind=testing.db, autocommit=False)
+
+ john,jack,jill,jane = sess.query(User).order_by(User.id).all()
+ sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete(synchronize_session=False)
+
+ assert john in sess and jill in sess
+
+ eq_(sess.query(User).order_by(User.id).all(), [jack,jane])
+
+ @testing.resolve_artifact_names
+ def test_delete_with_fetch_strategy(self):
+ sess = create_session(bind=testing.db, autocommit=False)
+
+ john,jack,jill,jane = sess.query(User).order_by(User.id).all()
+ sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete(synchronize_session='fetch')
+
+ assert john not in sess and jill not in sess
+
+ eq_(sess.query(User).order_by(User.id).all(), [jack,jane])
+
+ @testing.resolve_artifact_names
+ def test_delete_fallback(self):
+ sess = create_session(bind=testing.db, autocommit=False)
+
+ john,jack,jill,jane = sess.query(User).order_by(User.id).all()
+ sess.query(User).filter(User.name == select([func.max(User.name)])).delete()
+
+ assert john not in sess
+
+ eq_(sess.query(User).order_by(User.id).all(), [jack,jill,jane])
+
+ @testing.resolve_artifact_names
+ def test_update(self):
+ sess = create_session(bind=testing.db, autocommit=False)
+
+ john,jack,jill,jane = sess.query(User).order_by(User.id).all()
+ sess.query(User).filter(User.age > 29).update({'age': User.age - 10})
+
+ eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27])
+ eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27]))
+
+ @testing.resolve_artifact_names
+ def test_update_with_expire_strategy(self):
+ sess = create_session(bind=testing.db, autocommit=False)
+
+ john,jack,jill,jane = sess.query(User).order_by(User.id).all()
+ sess.query(User).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='expire')
+
+ eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27])
+ eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27]))
if __name__ == '__main__':
testenv.main()