]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- add a new "on mapper configured" event - handy !
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Dec 2010 18:22:12 +0000 (13:22 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Dec 2010 18:22:12 +0000 (13:22 -0500)
examples/mutable_events/__init__.py
examples/mutable_events/scalars.py
lib/sqlalchemy/orm/events.py
lib/sqlalchemy/orm/mapper.py

index 9802a109b5fe3816e0fb387877796b59d1d7a7d6..813dc5abd2b3e91b4f56bd57c6a1f5574fc7a1a5 100644 (file)
@@ -26,6 +26,14 @@ Subclassing ``dict`` to provide "mutation tracking" looks like::
         id = Column(Integer, primary_key=True)
         data = Column(JSONEncodedDict)
 
-    MutationDict.listen(Foo.data)
+    MutationDict.associate_with_attribute(Foo.data)
+
+The explicit step of associating ``MutationDict`` with ``Foo.data`` can be 
+automated across a class of columns using ``associate_with_type()``::
+
+    MutationDict.associate_with_type(JSONEncodedDict)
+    
+All subsequent mappings will have the ``MutationDict`` wrapper applied to
+all attributes with ``JSONEncodedDict`` as their type.
 
 """
\ No newline at end of file
index 4d434fd54a88f5629696f6edf2fbd86c5b9df555..b4d6b350d73d8bc5593c00c5a99a556cb0a8fd02 100644 (file)
@@ -1,5 +1,7 @@
 from sqlalchemy.orm.attributes import flag_modified
 from sqlalchemy import event
+from sqlalchemy.orm import mapper
+from sqlalchemy.util import memoized_property
 import weakref
 
 class TrackMutationsMixin(object):
@@ -7,54 +9,78 @@ class TrackMutationsMixin(object):
     events to a parent object.
     
     """
-    _key = None
-    _parent = None
-    
-    def _set_parent(self, parent, key):
-        self._parent = weakref.ref(parent)
-        self._key = key
+    @memoized_property
+    def _parents(self):
+        """Dictionary of parent object->attribute name on the parent."""
         
-    def _remove_parent(self):
-        del self._parent
+        return weakref.WeakKeyDictionary()
         
-    def on_change(self, key=None):
+    def on_change(self):
         """Subclasses should call this method whenever change events occur."""
         
-        if key is None:
-            key = self._key
-        if self._parent:
-            p = self._parent()
-            if p:
-                flag_modified(p, self._key)
+        for parent, key in self._parents.items():
+            flag_modified(parent, key)
     
     @classmethod
-    def listen(cls, attribute):
-        """Establish this type as a mutation listener for the given class and 
-        attribute name.
+    def associate_with_attribute(cls, attribute):
+        """Establish this type as a mutation listener for the given 
+        mapped descriptor.
         
         """
         key = attribute.key
         parent_cls = attribute.class_
         
         def on_load(state):
+            """Listen for objects loaded or refreshed.   
+            
+            Wrap the target data member's value with 
+            ``TrackMutationsMixin``.
+            
+            """
             val = state.dict.get(key, None)
             if val is not None:
                 val = cls(val)
                 state.dict[key] = val
-                val._set_parent(state.obj(), key)
+                val._parents[state.obj()] = key
 
         def on_set(target, value, oldvalue, initiator):
+            """Listen for set/replace events on the target
+            data member.
+            
+            Establish a weak reference to the parent object
+            on the incoming value, remove it for the one 
+            outgoing.
+            
+            """
+            
             if not isinstance(value, cls):
                 value = cls(value)
-            value._set_parent(target.obj(), key)
+            value._parents[target.obj()] = key
             if isinstance(oldvalue, cls):
-                oldvalue._remove_parent()
+                oldvalue._parents.pop(state.obj(), None)
             return value
         
         event.listen(parent_cls, 'on_load', on_load, raw=True)
         event.listen(parent_cls, 'on_refresh', on_load, raw=True)
         event.listen(attribute, 'on_set', on_set, raw=True, retval=True)
-
+    
+    @classmethod
+    def associate_with_type(cls, type_):
+        """Associate this wrapper with all future mapped columns 
+        of the given type.
+        
+        This is a convenience method that calls ``associate_with_attribute`` automatically.
+        
+        """
+        
+        def listen_for_type(mapper, class_):
+            for prop in mapper.iterate_properties:
+                if hasattr(prop, 'columns') and isinstance(prop.columns[0].type, type_):
+                    cls.listen(getattr(class_, prop.key))
+                    
+        event.listen(mapper, 'on_mapper_configured', listen_for_type)
+        
+        
 if __name__ == '__main__':
     from sqlalchemy import Column, Integer, VARCHAR, create_engine
     from sqlalchemy.orm import Session
@@ -83,7 +109,7 @@ if __name__ == '__main__':
             if value is not None:
                 value = simplejson.loads(value, use_decimal=True)
             return value
-
+    
     class MutationDict(TrackMutationsMixin, dict):
         def __init__(self, other):
             self.update(other)
@@ -95,15 +121,15 @@ if __name__ == '__main__':
         def __delitem__(self, key):
             dict.__delitem__(self, key)
             self.on_change()
-
+            
+    MutationDict.associate_with_type(JSONEncodedDict)
+    
     Base = declarative_base()
     class Foo(Base):
         __tablename__ = 'foo'
         id = Column(Integer, primary_key=True)
         data = Column(JSONEncodedDict)
-
-    MutationDict.listen(Foo.data)
-
+    
     e = create_engine('sqlite://', echo=True)
 
     Base.metadata.create_all(e)
index f511b0e4031f2c7f056f801dfec298efb4787915..48b559c635828f162792c1908768989d8899f734 100644 (file)
@@ -267,8 +267,11 @@ class MapperEvents(event.Events):
             event.Events.listen(target, identifier, fn)
         
     def on_instrument_class(self, mapper, class_):
-        """Receive a class when the mapper is first constructed, and has
-        applied instrumentation to the mapped class.
+        """Receive a class when the mapper is first constructed, 
+        before instrumentation is applied to the mapped class.
+        
+        This event is the earliest phase of mapper construction.
+        Most attributes of the mapper are not yet initialized.
         
         This listener can generally only be applied to the :class:`.Mapper`
         class overall.
@@ -278,7 +281,20 @@ class MapperEvents(event.Events):
         :param class\_: the mapped class.
         
         """
+    
+    def on_mapper_configured(self, mapper, class_):
+        """Called when the mapper for the class is fully configured.
 
+        This event is the latest phase of mapper construction.
+        The mapper should be in its final state.
+        
+        :param mapper: the :class:`.Mapper` which is the target
+         of this event.
+        :param class\_: the mapped class.
+        
+        """
+        # TODO: need coverage for this event
+        
     def on_translate_row(self, mapper, context, row):
         """Perform pre-processing on the given result row and return a
         new row instance.
@@ -818,6 +834,15 @@ class AttributeEvents(event.Events):
       or "append" event.   
     
     """
+
+    @classmethod
+    def accept_with(cls, target):
+        from sqlalchemy.orm import interfaces
+        # TODO: coverage
+        if isinstance(target, interfaces.MapperProperty):
+            return getattr(target.parent.class_, target.key)
+        else:
+            return target
     
     @classmethod
     def listen(cls, target, identifier, fn, active_history=False, 
index f2bc045700722e6bd0457e3c517f5caa4c3e8005..e9271008edae1fda5bd98b109678387d3c9bd7d5 100644 (file)
@@ -2401,6 +2401,7 @@ def configure_mappers():
                     try:
                         mapper._post_configure_properties()
                         mapper._expire_memoizations()
+                        mapper.dispatch.on_mapper_configured(mapper, mapper.class_)
                     except:
                         exc = sys.exc_info()[1]
                         if not hasattr(exc, '_configure_failed'):