From: Mike Bayer Date: Tue, 2 Sep 2008 17:57:35 +0000 (+0000) Subject: - AttributeListener has been refined such that the event X-Git-Tag: rel_0_5rc1~26 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3e25e6e6b05c39b15deda65921d411ec8cb341ae;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - AttributeListener has been refined such that the event is fired before the mutation actually occurs. Addtionally, the append() and set() methods must now return the given value, which is used as the value to be used in the mutation operation. This allows creation of validating AttributeListeners which raise before the action actually occurs, and which can change the given value into something else before its used. A new example "validate_attributes.py" shows one such recipe for doing this. AttributeListener helper functions are also on the way. --- diff --git a/CHANGES b/CHANGES index 1204a347c2..5d91ac46c6 100644 --- a/CHANGES +++ b/CHANGES @@ -33,6 +33,17 @@ CHANGES clause will appear in the WHERE clause of the query as well since this discrimination has multiple trigger points. + - AttributeListener has been refined such that the event + is fired before the mutation actually occurs. Addtionally, + the append() and set() methods must now return the given value, + which is used as the value to be used in the mutation operation. + This allows creation of validating AttributeListeners which + raise before the action actually occurs, and which can change + the given value into something else before its used. + A new example "validate_attributes.py" shows one such recipe + for doing this. AttributeListener helper functions are + also on the way. + - class.someprop.in_() raises NotImplementedError pending the implementation of "in_" for relation [ticket:1140] diff --git a/examples/custom_attributes/listen_for_events.py b/examples/custom_attributes/listen_for_events.py index 71f8bbadec..e980e61edc 100644 --- a/examples/custom_attributes/listen_for_events.py +++ b/examples/custom_attributes/listen_for_events.py @@ -7,9 +7,9 @@ from sqlalchemy.orm.interfaces import AttributeExtension, InstrumentationManager class InstallListeners(InstrumentationManager): def instrument_attribute(self, class_, key, inst): - """Add an event listener to all InstrumentedAttributes.""" + """Add an event listener to an InstrumentedAttribute.""" - inst.impl.extensions.append(AttributeListener(key)) + inst.impl.extensions.insert(0, AttributeListener(key)) return super(InstallListeners, self).instrument_attribute(class_, key, inst) class AttributeListener(AttributeExtension): @@ -25,12 +25,14 @@ class AttributeListener(AttributeExtension): def append(self, state, value, initiator): self._report(state, value, None, "appended") + return value def remove(self, state, value, initiator): self._report(state, value, None, "removed") def set(self, state, value, oldvalue, initiator): self._report(state, value, oldvalue, "set") + return value def _report(self, state, value, oldvalue, verb): state.obj().receive_change_event(verb, self.key, value, oldvalue) diff --git a/examples/custom_attributes/validate_attributes.py b/examples/custom_attributes/validate_attributes.py new file mode 100644 index 0000000000..63b2529fdd --- /dev/null +++ b/examples/custom_attributes/validate_attributes.py @@ -0,0 +1,117 @@ +""" +Illustrates how to use AttributeExtension to create attribute validators. + +""" + +from sqlalchemy.orm.interfaces import AttributeExtension, InstrumentationManager + +class InstallValidators(InstrumentationManager): + """Searches a class for methods with a '_validates' attribute and assembles Validators.""" + + def __init__(self, cls): + self.validators = {} + for k in dir(cls): + item = getattr(cls, k) + if hasattr(item, '_validates'): + self.validators[item._validates] = item + + def instrument_attribute(self, class_, key, inst): + """Add an event listener to an InstrumentedAttribute.""" + + if key in self.validators: + inst.impl.extensions.insert(0, Validator(key, self.validators[key])) + return super(InstallValidators, self).instrument_attribute(class_, key, inst) + +class Validator(AttributeExtension): + """Validates an attribute, given the key and a validation function.""" + + def __init__(self, key, validator): + self.key = key + self.validator = validator + + def append(self, state, value, initiator): + return self.validator(state.obj(), value) + + def set(self, state, value, oldvalue, initiator): + return self.validator(state.obj(), value) + +def validates(key): + """Mark a method as validating a named attribute.""" + + def wrap(fn): + fn._validates = key + return fn + return wrap + +if __name__ == '__main__': + + from sqlalchemy import * + from sqlalchemy.orm import * + from sqlalchemy.ext.declarative import declarative_base + import datetime + + Base = declarative_base(engine=create_engine('sqlite://', echo=True)) + Base.__sa_instrumentation_manager__ = InstallValidators + + class MyMappedClass(Base): + __tablename__ = "mytable" + + id = Column(Integer, primary_key=True) + date = Column(Date) + related_id = Column(Integer, ForeignKey("related.id")) + related = relation("Related", backref="mapped") + + @validates('date') + def check_date(self, value): + if isinstance(value, str): + m, d, y = [int(x) for x in value.split('/')] + return datetime.date(y, m, d) + else: + assert isinstance(value, datetime.date) + return value + + @validates('related') + def check_related(self, value): + assert value.data == 'r1' + return value + + def __str__(self): + return "MyMappedClass(date=%r)" % self.date + + class Related(Base): + __tablename__ = "related" + + id = Column(Integer, primary_key=True) + data = Column(String(50)) + + def __str__(self): + return "Related(data=%r)" % self.data + + Base.metadata.create_all() + session = sessionmaker()() + + r1 = Related(data='r1') + r2 = Related(data='r2') + m1 = MyMappedClass(date='5/2/2005', related=r1) + m2 = MyMappedClass(date=datetime.date(2008, 10, 15)) + r1.mapped.append(m2) + + try: + m1.date = "this is not a date" + except: + pass + assert m1.date == datetime.date(2005, 5, 2) + + try: + m2.related = r2 + except: + pass + assert m2.related is r1 + + session.add(m1) + session.commit() + assert session.query(MyMappedClass.date).order_by(MyMappedClass.date).all() == [ + (datetime.date(2005, 5, 2),), + (datetime.date(2008, 10, 15),) + ] + \ No newline at end of file diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index ddebc563f0..17fea7854f 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -382,8 +382,8 @@ class ScalarAttributeImpl(AttributeImpl): state.modified_event(self, False, old) if self.extensions: - del state.dict[self.key] self.fire_remove_event(state, old, None) + del state.dict[self.key] else: del state.dict[self.key] @@ -403,14 +403,15 @@ class ScalarAttributeImpl(AttributeImpl): state.modified_event(self, False, old) if self.extensions: + value = self.fire_replace_event(state, value, old, initiator) state.dict[self.key] = value - self.fire_replace_event(state, value, old, initiator) else: state.dict[self.key] = value def fire_replace_event(self, state, value, previous, initiator): for ext in self.extensions: - ext.set(state, value, previous, initiator or self) + value = ext.set(state, value, previous, initiator or self) + return value def fire_remove_event(self, state, value, initiator): for ext in self.extensions: @@ -457,8 +458,8 @@ class MutableScalarAttributeImpl(ScalarAttributeImpl): if self.extensions: old = self.get(state) + value = self.fire_replace_event(state, value, old, initiator) state.dict[self.key] = value - self.fire_replace_event(state, value, old, initiator) else: state.dict[self.key] = value @@ -483,9 +484,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): def delete(self, state): old = self.get(state) - # TODO: catch key errors, convert to attributeerror? - del state.dict[self.key] self.fire_remove_event(state, old, self) + del state.dict[self.key] def get_history(self, state, passive=False): if self.key in state.dict: @@ -510,8 +510,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): # may want to add options to allow the get() here to be passive old = self.get(state) + value = self.fire_replace_event(state, value, old, initiator) state.dict[self.key] = value - self.fire_replace_event(state, value, old, initiator) def fire_remove_event(self, state, value, initiator): state.modified_event(self, False, value) @@ -532,7 +532,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): self.sethasparent(instance_state(previous), False) for ext in self.extensions: - ext.set(state, value, previous, initiator or self) + value = ext.set(state, value, previous, initiator or self) + return value class CollectionAttributeImpl(AttributeImpl): @@ -582,7 +583,8 @@ class CollectionAttributeImpl(AttributeImpl): self.sethasparent(instance_state(value), True) for ext in self.extensions: - ext.append(state, value, initiator or self) + value = ext.append(state, value, initiator or self) + return value def fire_pre_remove_event(self, state, initiator): state.modified_event(self, True, NEVER_SET, passive=True) @@ -624,8 +626,8 @@ class CollectionAttributeImpl(AttributeImpl): collection = self.get_collection(state, passive=passive) if collection is PASSIVE_NORESULT: + value = self.fire_append_event(state, value, initiator) state.get_pending(self.key).append(value) - self.fire_append_event(state, value, initiator) else: collection.append_with_event(value, initiator) @@ -635,8 +637,8 @@ class CollectionAttributeImpl(AttributeImpl): collection = self.get_collection(state, passive=passive) if collection is PASSIVE_NORESULT: - state.get_pending(self.key).remove(value) self.fire_remove_event(state, value, initiator) + state.get_pending(self.key).remove(value) else: collection.remove_with_event(value, initiator) @@ -745,7 +747,7 @@ class GenericBackrefExtension(interfaces.AttributeExtension): def set(self, state, child, oldchild, initiator): if oldchild is child: - return + return child if oldchild is not None: # With lazy=None, there's no guarantee that the full collection is # present when updating via a backref. @@ -758,11 +760,13 @@ class GenericBackrefExtension(interfaces.AttributeExtension): if child is not None: new_state = instance_state(child) new_state.get_impl(self.key).append(new_state, state.obj(), initiator, passive=True) - + return child + def append(self, state, child, initiator): child_state = instance_state(child) child_state.get_impl(self.key).append(child_state, state.obj(), initiator, passive=True) - + return child + def remove(self, state, child, initiator): if child is not None: child_state = instance_state(child) diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index f8570dd5fc..497ef59411 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -584,7 +584,9 @@ class CollectionAdapter(object): """ if initiator is not False and item is not None: - self.attr.fire_append_event(self.owner_state, item, initiator) + return self.attr.fire_append_event(self.owner_state, item, initiator) + else: + return item def fire_remove_event(self, item, initiator=None): """Notify that a entity has been removed from the collection. @@ -881,11 +883,13 @@ def _instrument_membership_mutator(method, before, argument, after): def __set(collection, item, _sa_initiator=None): """Run set events, may eventually be inlined into decorators.""" + if _sa_initiator is not False and item is not None: executor = getattr(collection, '_sa_adapter', None) if executor: - getattr(executor, 'fire_append_event')(item, _sa_initiator) - + item = getattr(executor, 'fire_append_event')(item, _sa_initiator) + return item + def __del(collection, item, _sa_initiator=None): """Run del events, may eventually be inlined into decorators.""" if _sa_initiator is not False and item is not None: @@ -908,7 +912,7 @@ def _list_decorators(): def append(fn): def append(self, item, _sa_initiator=None): - __set(self, item, _sa_initiator) + item = __set(self, item, _sa_initiator) fn(self, item) _tidy(append) return append @@ -924,7 +928,7 @@ def _list_decorators(): def insert(fn): def insert(self, index, value): - __set(self, value) + value = __set(self, value) fn(self, index, value) _tidy(insert) return insert @@ -935,7 +939,7 @@ def _list_decorators(): existing = self[index] if existing is not None: __del(self, existing) - __set(self, value) + value = __set(self, value) fn(self, index, value) else: # slice assignment requires __delitem__, insert, __len__ @@ -985,8 +989,7 @@ def _list_decorators(): def __setslice__(self, start, end, values): for value in self[start:end]: __del(self, value) - for value in values: - __set(self, value) + values = [__set(self, value) for value in values] fn(self, start, end, values) _tidy(__setslice__) return __setslice__ @@ -1047,7 +1050,7 @@ def _dict_decorators(): def __setitem__(self, key, value, _sa_initiator=None): if key in self: __del(self, self[key], _sa_initiator) - __set(self, value, _sa_initiator) + value = __set(self, value, _sa_initiator) fn(self, key, value) _tidy(__setitem__) return __setitem__ @@ -1154,7 +1157,7 @@ def _set_decorators(): def add(fn): def add(self, value, _sa_initiator=None): if value not in self: - __set(self, value, _sa_initiator) + value = __set(self, value, _sa_initiator) # testlib.pragma exempt:__hash__ fn(self, value) _tidy(add) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 6dd2225c88..495d22be16 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -705,18 +705,35 @@ class AttributeExtension(object): """An event handler for individual attribute change events. AttributeExtension is assembled within the descriptors associated - with a mapped class. + with a mapped class. """ def append(self, state, value, initiator): - pass + """Receive a collection append event. + + The returned value will be used as the actual value to be + appended. + + """ + return value def remove(self, state, value, initiator): + """Receive a remove event. + + No return value is defined. + + """ pass def set(self, state, value, oldvalue, initiator): - pass + """Receive a set event. + + The returned value will be used as the actual value to be + set. + + """ + return value class StrategizedOption(PropertyOption): diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index f4d2b51bd9..67a886306c 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -46,7 +46,8 @@ class UOWEventHandler(interfaces.AttributeExtension): prop = _state_mapper(state).get_property(self.key) if prop.cascade.save_update and item not in sess: sess.save_or_update(item) - + return item + def remove(self, state, item, initiator): sess = _state_session(state) if sess: @@ -60,7 +61,7 @@ class UOWEventHandler(interfaces.AttributeExtension): def set(self, state, newvalue, oldvalue, initiator): # process "save_update" cascade rules for when an instance is attached to another instance if oldvalue is newvalue: - return + return newvalue sess = _state_session(state) if sess: prop = _state_mapper(state).get_property(self.key) @@ -68,7 +69,7 @@ class UOWEventHandler(interfaces.AttributeExtension): sess.save_or_update(newvalue) if prop.cascade.delete_orphan and oldvalue in sess.new: sess.expunge(oldvalue) - + return newvalue def register_attribute(class_, key, *args, **kwargs): """overrides attributes.register_attribute() to add UOW event handlers diff --git a/test/orm/attributes.py b/test/orm/attributes.py index 4e77935d47..3fe2294c49 100644 --- a/test/orm/attributes.py +++ b/test/orm/attributes.py @@ -214,6 +214,7 @@ class AttributesTest(_base.ORMTest): def set(self, state, child, oldchild, initiator): results.append(("set", state.obj(), child, oldchild)) + return child attributes.register_class(Foo) attributes.register_attribute(Foo, 'x', uselist=False, mutable_scalars=False, useobject=False, extension=ReceiveEvents()) @@ -1250,6 +1251,47 @@ class HistoryTest(_base.ORMTest): assert f.bar is None eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], (), [bar1])) +class ListenerTest(_base.ORMTest): + def test_receive_changes(self): + """test that Listeners can mutate the given value. + + This is a rudimentary test which would be better suited by a full-blown inclusion + into collection.py. + + """ + class Foo(object): + pass + class Bar(object): + pass + + class AlteringListener(AttributeExtension): + def append(self, state, child, initiator): + b2 = Bar() + b2.data = b1.data + " appended" + return b2 + + def set(self, state, value, oldvalue, initiator): + return value + " modified" + + attributes.register_class(Foo) + attributes.register_class(Bar) + attributes.register_attribute(Foo, 'data', uselist=False, useobject=False, extension=AlteringListener()) + attributes.register_attribute(Foo, 'barlist', uselist=True, useobject=True, extension=AlteringListener()) + attributes.register_attribute(Foo, 'barset', typecallable=set, uselist=True, useobject=True, extension=AlteringListener()) + attributes.register_attribute(Bar, 'data', uselist=False, useobject=False) + + f1 = Foo() + f1.data = "some data" + eq_(f1.data, "some data modified") + b1 = Bar() + b1.data = "some bar" + f1.barlist.append(b1) + assert b1.data == "some bar" + assert f1.barlist[0].data == "some bar appended" + + f1.barset.add(b1) + assert f1.barset.pop().data == "some bar appended" + if __name__ == "__main__": testenv.main() diff --git a/test/orm/collection.py b/test/orm/collection.py index fd0f389092..0d85848733 100644 --- a/test/orm/collection.py +++ b/test/orm/collection.py @@ -22,15 +22,19 @@ class Canary(sa.orm.interfaces.AttributeExtension): assert value not in self.added self.data.add(value) self.added.add(value) + return value def remove(self, obj, value, initiator): assert value not in self.removed self.data.remove(value) self.removed.add(value) def set(self, obj, value, oldvalue, initiator): + if isinstance(value, str): + value = CollectionsTest.entity_maker() + if oldvalue is not None: self.remove(obj, oldvalue, None) self.append(obj, value, None) - + return value class CollectionsTest(_base.ORMTest): class Entity(object):