]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- AttributeListener has been refined such that the event
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 2 Sep 2008 17:57:35 +0000 (17:57 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 2 Sep 2008 17:57:35 +0000 (17:57 +0000)
is fired before the mutation actually occurs.  Addtionally,
the append() and set() methods must now return the given value,
which is used as the value to be used in the mutation operation.
This allows creation of validating AttributeListeners which
raise before the action actually occurs, and which can change
the given value into something else before its used.
A new example "validate_attributes.py" shows one such recipe
for doing this.   AttributeListener helper functions are
also on the way.

CHANGES
examples/custom_attributes/listen_for_events.py
examples/custom_attributes/validate_attributes.py [new file with mode: 0644]
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/unitofwork.py
test/orm/attributes.py
test/orm/collection.py

diff --git a/CHANGES b/CHANGES
index 1204a347c27121014c3363d77d1355e37f1dbbaf..5d91ac46c647fe683e7e8d27204ef156f2c14be1 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -33,6 +33,17 @@ CHANGES
       clause will appear in the WHERE clause of the query as well
       since this discrimination has multiple trigger points.
 
+    - AttributeListener has been refined such that the event
+      is fired before the mutation actually occurs.  Addtionally,
+      the append() and set() methods must now return the given value,
+      which is used as the value to be used in the mutation operation.
+      This allows creation of validating AttributeListeners which
+      raise before the action actually occurs, and which can change
+      the given value into something else before its used.
+      A new example "validate_attributes.py" shows one such recipe
+      for doing this.   AttributeListener helper functions are
+      also on the way.
+      
     - class.someprop.in_() raises NotImplementedError pending the
       implementation of "in_" for relation [ticket:1140]
 
index 71f8bbadecd32fc43ef88badcc65588b27fec381..e980e61edcb7c08ed46b2660df39a477bac7ead1 100644 (file)
@@ -7,9 +7,9 @@ from sqlalchemy.orm.interfaces import AttributeExtension, InstrumentationManager
 
 class InstallListeners(InstrumentationManager):
     def instrument_attribute(self, class_, key, inst):
-        """Add an event listener to all InstrumentedAttributes."""
+        """Add an event listener to an InstrumentedAttribute."""
         
-        inst.impl.extensions.append(AttributeListener(key))
+        inst.impl.extensions.insert(0, AttributeListener(key))
         return super(InstallListeners, self).instrument_attribute(class_, key, inst)
         
 class AttributeListener(AttributeExtension):
@@ -25,12 +25,14 @@ class AttributeListener(AttributeExtension):
     
     def append(self, state, value, initiator):
         self._report(state, value, None, "appended")
+        return value
 
     def remove(self, state, value, initiator):
         self._report(state, value, None, "removed")
 
     def set(self, state, value, oldvalue, initiator):
         self._report(state, value, oldvalue, "set")
+        return value
     
     def _report(self, state, value, oldvalue, verb):
         state.obj().receive_change_event(verb, self.key, value, oldvalue)
diff --git a/examples/custom_attributes/validate_attributes.py b/examples/custom_attributes/validate_attributes.py
new file mode 100644 (file)
index 0000000..63b2529
--- /dev/null
@@ -0,0 +1,117 @@
+"""
+Illustrates how to use AttributeExtension to create attribute validators.
+
+"""
+
+from sqlalchemy.orm.interfaces import AttributeExtension, InstrumentationManager
+
+class InstallValidators(InstrumentationManager):
+    """Searches a class for methods with a '_validates' attribute and assembles Validators."""
+    
+    def __init__(self, cls):
+        self.validators = {}
+        for k in dir(cls):
+            item = getattr(cls, k)
+            if hasattr(item, '_validates'):
+                self.validators[item._validates] = item
+                
+    def instrument_attribute(self, class_, key, inst):
+        """Add an event listener to an InstrumentedAttribute."""
+        
+        if key in self.validators:
+            inst.impl.extensions.insert(0, Validator(key, self.validators[key]))
+        return super(InstallValidators, self).instrument_attribute(class_, key, inst)
+        
+class Validator(AttributeExtension):
+    """Validates an attribute, given the key and a validation function."""
+    
+    def __init__(self, key, validator):
+        self.key = key
+        self.validator = validator
+    
+    def append(self, state, value, initiator):
+        return self.validator(state.obj(), value)
+
+    def set(self, state, value, oldvalue, initiator):
+        return self.validator(state.obj(), value)
+
+def validates(key):
+    """Mark a method as validating a named attribute."""
+    
+    def wrap(fn):
+        fn._validates = key
+        return fn
+    return wrap
+
+if __name__ == '__main__':
+
+    from sqlalchemy import *
+    from sqlalchemy.orm import *
+    from sqlalchemy.ext.declarative import declarative_base
+    import datetime
+    
+    Base = declarative_base(engine=create_engine('sqlite://', echo=True))
+    Base.__sa_instrumentation_manager__ = InstallValidators
+
+    class MyMappedClass(Base):
+        __tablename__ = "mytable"
+    
+        id = Column(Integer, primary_key=True)
+        date = Column(Date)
+        related_id = Column(Integer, ForeignKey("related.id"))
+        related = relation("Related", backref="mapped")
+
+        @validates('date')
+        def check_date(self, value):
+            if isinstance(value, str):
+                m, d, y = [int(x) for x in value.split('/')]
+                return datetime.date(y, m, d)
+            else:
+                assert isinstance(value, datetime.date)
+                return value
+        
+        @validates('related')
+        def check_related(self, value):
+            assert value.data == 'r1'
+            return value
+            
+        def __str__(self):
+            return "MyMappedClass(date=%r)" % self.date
+            
+    class Related(Base):
+        __tablename__ = "related"
+
+        id = Column(Integer, primary_key=True)
+        data = Column(String(50))
+
+        def __str__(self):
+            return "Related(data=%r)" % self.data
+    
+    Base.metadata.create_all()
+    session = sessionmaker()()
+    
+    r1 = Related(data='r1')
+    r2 = Related(data='r2')
+    m1 = MyMappedClass(date='5/2/2005', related=r1)
+    m2 = MyMappedClass(date=datetime.date(2008, 10, 15))
+    r1.mapped.append(m2)
+
+    try:
+        m1.date = "this is not a date"
+    except:
+        pass
+    assert m1.date == datetime.date(2005, 5, 2)
+    
+    try:
+        m2.related = r2
+    except:
+        pass
+    assert m2.related is r1
+    
+    session.add(m1)
+    session.commit()
+    assert session.query(MyMappedClass.date).order_by(MyMappedClass.date).all() == [
+        (datetime.date(2005, 5, 2),),
+        (datetime.date(2008, 10, 15),)
+    ]
+    
\ No newline at end of file
index ddebc563f09b20bf9a2ab860af8f9ed0bb472bb4..17fea7854f6a35a6df12e92f40ff527abdfabbbb 100644 (file)
@@ -382,8 +382,8 @@ class ScalarAttributeImpl(AttributeImpl):
         state.modified_event(self, False, old)
 
         if self.extensions:
-            del state.dict[self.key]
             self.fire_remove_event(state, old, None)
+            del state.dict[self.key]
         else:
             del state.dict[self.key]
 
@@ -403,14 +403,15 @@ class ScalarAttributeImpl(AttributeImpl):
         state.modified_event(self, False, old)
 
         if self.extensions:
+            value = self.fire_replace_event(state, value, old, initiator)
             state.dict[self.key] = value
-            self.fire_replace_event(state, value, old, initiator)
         else:
             state.dict[self.key] = value
 
     def fire_replace_event(self, state, value, previous, initiator):
         for ext in self.extensions:
-            ext.set(state, value, previous, initiator or self)
+            value = ext.set(state, value, previous, initiator or self)
+        return value
 
     def fire_remove_event(self, state, value, initiator):
         for ext in self.extensions:
@@ -457,8 +458,8 @@ class MutableScalarAttributeImpl(ScalarAttributeImpl):
 
         if self.extensions:
             old = self.get(state)
+            value = self.fire_replace_event(state, value, old, initiator)
             state.dict[self.key] = value
-            self.fire_replace_event(state, value, old, initiator)
         else:
             state.dict[self.key] = value
 
@@ -483,9 +484,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
 
     def delete(self, state):
         old = self.get(state)
-        # TODO: catch key errors, convert to attributeerror?
-        del state.dict[self.key]
         self.fire_remove_event(state, old, self)
+        del state.dict[self.key]
 
     def get_history(self, state, passive=False):
         if self.key in state.dict:
@@ -510,8 +510,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
 
         # may want to add options to allow the get() here to be passive
         old = self.get(state)
+        value = self.fire_replace_event(state, value, old, initiator)
         state.dict[self.key] = value
-        self.fire_replace_event(state, value, old, initiator)
 
     def fire_remove_event(self, state, value, initiator):
         state.modified_event(self, False, value)
@@ -532,7 +532,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
                 self.sethasparent(instance_state(previous), False)
 
         for ext in self.extensions:
-            ext.set(state, value, previous, initiator or self)
+            value = ext.set(state, value, previous, initiator or self)
+        return value
 
 
 class CollectionAttributeImpl(AttributeImpl):
@@ -582,7 +583,8 @@ class CollectionAttributeImpl(AttributeImpl):
             self.sethasparent(instance_state(value), True)
 
         for ext in self.extensions:
-            ext.append(state, value, initiator or self)
+            value = ext.append(state, value, initiator or self)
+        return value
 
     def fire_pre_remove_event(self, state, initiator):
         state.modified_event(self, True, NEVER_SET, passive=True)
@@ -624,8 +626,8 @@ class CollectionAttributeImpl(AttributeImpl):
 
         collection = self.get_collection(state, passive=passive)
         if collection is PASSIVE_NORESULT:
+            value = self.fire_append_event(state, value, initiator)
             state.get_pending(self.key).append(value)
-            self.fire_append_event(state, value, initiator)
         else:
             collection.append_with_event(value, initiator)
 
@@ -635,8 +637,8 @@ class CollectionAttributeImpl(AttributeImpl):
 
         collection = self.get_collection(state, passive=passive)
         if collection is PASSIVE_NORESULT:
-            state.get_pending(self.key).remove(value)
             self.fire_remove_event(state, value, initiator)
+            state.get_pending(self.key).remove(value)
         else:
             collection.remove_with_event(value, initiator)
 
@@ -745,7 +747,7 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
 
     def set(self, state, child, oldchild, initiator):
         if oldchild is child:
-            return
+            return child
         if oldchild is not None:
             # With lazy=None, there's no guarantee that the full collection is
             # present when updating via a backref.
@@ -758,11 +760,13 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
         if child is not None:
             new_state = instance_state(child)
             new_state.get_impl(self.key).append(new_state, state.obj(), initiator, passive=True)
-
+        return child
+        
     def append(self, state, child, initiator):
         child_state = instance_state(child)
         child_state.get_impl(self.key).append(child_state, state.obj(), initiator, passive=True)
-
+        return child
+        
     def remove(self, state, child, initiator):
         if child is not None:
             child_state = instance_state(child)
index f8570dd5fc1d8f4469aefc5c090fa230a55b6511..497ef5941162d31b74088d8f7e4f293b4083e3ef 100644 (file)
@@ -584,7 +584,9 @@ class CollectionAdapter(object):
 
         """
         if initiator is not False and item is not None:
-            self.attr.fire_append_event(self.owner_state, item, initiator)
+            return self.attr.fire_append_event(self.owner_state, item, initiator)
+        else:
+            return item
 
     def fire_remove_event(self, item, initiator=None):
         """Notify that a entity has been removed from the collection.
@@ -881,11 +883,13 @@ def _instrument_membership_mutator(method, before, argument, after):
 
 def __set(collection, item, _sa_initiator=None):
     """Run set events, may eventually be inlined into decorators."""
+
     if _sa_initiator is not False and item is not None:
         executor = getattr(collection, '_sa_adapter', None)
         if executor:
-            getattr(executor, 'fire_append_event')(item, _sa_initiator)
-
+            item = getattr(executor, 'fire_append_event')(item, _sa_initiator)
+    return item
+    
 def __del(collection, item, _sa_initiator=None):
     """Run del events, may eventually be inlined into decorators."""
     if _sa_initiator is not False and item is not None:
@@ -908,7 +912,7 @@ def _list_decorators():
 
     def append(fn):
         def append(self, item, _sa_initiator=None):
-            __set(self, item, _sa_initiator)
+            item = __set(self, item, _sa_initiator)
             fn(self, item)
         _tidy(append)
         return append
@@ -924,7 +928,7 @@ def _list_decorators():
 
     def insert(fn):
         def insert(self, index, value):
-            __set(self, value)
+            value = __set(self, value)
             fn(self, index, value)
         _tidy(insert)
         return insert
@@ -935,7 +939,7 @@ def _list_decorators():
                 existing = self[index]
                 if existing is not None:
                     __del(self, existing)
-                __set(self, value)
+                value = __set(self, value)
                 fn(self, index, value)
             else:
                 # slice assignment requires __delitem__, insert, __len__
@@ -985,8 +989,7 @@ def _list_decorators():
         def __setslice__(self, start, end, values):
             for value in self[start:end]:
                 __del(self, value)
-            for value in values:
-                __set(self, value)
+            values = [__set(self, value) for value in values]
             fn(self, start, end, values)
         _tidy(__setslice__)
         return __setslice__
@@ -1047,7 +1050,7 @@ def _dict_decorators():
         def __setitem__(self, key, value, _sa_initiator=None):
             if key in self:
                 __del(self, self[key], _sa_initiator)
-            __set(self, value, _sa_initiator)
+            value = __set(self, value, _sa_initiator)
             fn(self, key, value)
         _tidy(__setitem__)
         return __setitem__
@@ -1154,7 +1157,7 @@ def _set_decorators():
     def add(fn):
         def add(self, value, _sa_initiator=None):
             if value not in self:
-                __set(self, value, _sa_initiator)
+                value = __set(self, value, _sa_initiator)
             # testlib.pragma exempt:__hash__
             fn(self, value)
         _tidy(add)
index 6dd2225c881a16c213eb0dddb9a81090deb05079..495d22be168bd2cee6c8663fad9dd250c1305943 100644 (file)
@@ -705,18 +705,35 @@ class AttributeExtension(object):
     """An event handler for individual attribute change events.
     
     AttributeExtension is assembled within the descriptors associated 
-    with a mapped class.
+    with a mapped class. 
     
     """
 
     def append(self, state, value, initiator):
-        pass
+        """Receive a collection append event.
+        
+        The returned value will be used as the actual value to be
+        appended.
+        
+        """
+        return value
 
     def remove(self, state, value, initiator):
+        """Receive a remove event.
+        
+        No return value is defined.
+        
+        """
         pass
 
     def set(self, state, value, oldvalue, initiator):
-        pass
+        """Receive a set event.
+        
+        The returned value will be used as the actual value to be
+        set.
+        
+        """
+        return value
 
 
 class StrategizedOption(PropertyOption):
index f4d2b51bd9c9451e34a12561cfdf61db5c02117f..67a886306cdf48598bc97606f465432abe7ca6eb 100644 (file)
@@ -46,7 +46,8 @@ class UOWEventHandler(interfaces.AttributeExtension):
             prop = _state_mapper(state).get_property(self.key)
             if prop.cascade.save_update and item not in sess:
                 sess.save_or_update(item)
-
+        return item
+        
     def remove(self, state, item, initiator):
         sess = _state_session(state)
         if sess:
@@ -60,7 +61,7 @@ class UOWEventHandler(interfaces.AttributeExtension):
     def set(self, state, newvalue, oldvalue, initiator):
         # process "save_update" cascade rules for when an instance is attached to another instance
         if oldvalue is newvalue:
-            return
+            return newvalue
         sess = _state_session(state)
         if sess:
             prop = _state_mapper(state).get_property(self.key)
@@ -68,7 +69,7 @@ class UOWEventHandler(interfaces.AttributeExtension):
                 sess.save_or_update(newvalue)
             if prop.cascade.delete_orphan and oldvalue in sess.new:
                 sess.expunge(oldvalue)
-
+        return newvalue
 
 def register_attribute(class_, key, *args, **kwargs):
     """overrides attributes.register_attribute() to add UOW event handlers
index 4e77935d47f5a5aa6778ff0fa5b923e77af0cf42..3fe2294c493627ea73c9a5eee9c88b8662ca96dd 100644 (file)
@@ -214,6 +214,7 @@ class AttributesTest(_base.ORMTest):
 
             def set(self, state, child, oldchild, initiator):
                 results.append(("set", state.obj(), child, oldchild))
+                return child
         
         attributes.register_class(Foo)
         attributes.register_attribute(Foo, 'x', uselist=False, mutable_scalars=False, useobject=False, extension=ReceiveEvents())
@@ -1250,6 +1251,47 @@ class HistoryTest(_base.ORMTest):
         assert f.bar is None
         eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], (), [bar1]))
 
+class ListenerTest(_base.ORMTest):
+    def test_receive_changes(self):
+        """test that Listeners can mutate the given value.
+        
+        This is a rudimentary test which would be better suited by a full-blown inclusion
+        into collection.py.
+        
+        """
+        class Foo(object):
+            pass
+        class Bar(object):
+            pass
+
+        class AlteringListener(AttributeExtension):
+            def append(self, state, child, initiator):
+                b2 = Bar()
+                b2.data = b1.data + " appended"
+                return b2
+
+            def set(self, state, value, oldvalue, initiator):
+                return value + " modified"
+
+        attributes.register_class(Foo)
+        attributes.register_class(Bar)
+        attributes.register_attribute(Foo, 'data', uselist=False, useobject=False, extension=AlteringListener())
+        attributes.register_attribute(Foo, 'barlist', uselist=True, useobject=True, extension=AlteringListener())
+        attributes.register_attribute(Foo, 'barset', typecallable=set, uselist=True, useobject=True, extension=AlteringListener())
+        attributes.register_attribute(Bar, 'data', uselist=False, useobject=False)
+        
+        f1 = Foo()
+        f1.data = "some data"
+        eq_(f1.data, "some data modified")
+        b1 = Bar()
+        b1.data = "some bar"
+        f1.barlist.append(b1)
+        assert b1.data == "some bar"
+        assert f1.barlist[0].data == "some bar appended"
+        
+        f1.barset.add(b1)
+        assert f1.barset.pop().data == "some bar appended"
+    
     
 if __name__ == "__main__":
     testenv.main()
index fd0f3890923393c58591cbe03ade5d796c784ab8..0d858487334113e7313220c36a6e33d4a52ccdfc 100644 (file)
@@ -22,15 +22,19 @@ class Canary(sa.orm.interfaces.AttributeExtension):
         assert value not in self.added
         self.data.add(value)
         self.added.add(value)
+        return value
     def remove(self, obj, value, initiator):
         assert value not in self.removed
         self.data.remove(value)
         self.removed.add(value)
     def set(self, obj, value, oldvalue, initiator):
+        if isinstance(value, str):
+            value = CollectionsTest.entity_maker()
+
         if oldvalue is not None:
             self.remove(obj, oldvalue, None)
         self.append(obj, value, None)
-
+        return value
 
 class CollectionsTest(_base.ORMTest):
     class Entity(object):