from sqlalchemy.orm.attributes import flag_modified
from sqlalchemy import event
+from sqlalchemy.orm import mapper
+from sqlalchemy.util import memoized_property
import weakref
class TrackMutationsMixin(object):
events to a parent object.
"""
- _key = None
- _parent = None
-
- def _set_parent(self, parent, key):
- self._parent = weakref.ref(parent)
- self._key = key
+ @memoized_property
+ def _parents(self):
+ """Dictionary of parent object->attribute name on the parent."""
- def _remove_parent(self):
- del self._parent
+ return weakref.WeakKeyDictionary()
- def on_change(self, key=None):
+ def on_change(self):
"""Subclasses should call this method whenever change events occur."""
- if key is None:
- key = self._key
- if self._parent:
- p = self._parent()
- if p:
- flag_modified(p, self._key)
+ for parent, key in self._parents.items():
+ flag_modified(parent, key)
@classmethod
- def listen(cls, attribute):
- """Establish this type as a mutation listener for the given class and
- attribute name.
+ def associate_with_attribute(cls, attribute):
+ """Establish this type as a mutation listener for the given
+ mapped descriptor.
"""
key = attribute.key
parent_cls = attribute.class_
def on_load(state):
+ """Listen for objects loaded or refreshed.
+
+ Wrap the target data member's value with
+ ``TrackMutationsMixin``.
+
+ """
val = state.dict.get(key, None)
if val is not None:
val = cls(val)
state.dict[key] = val
- val._set_parent(state.obj(), key)
+ val._parents[state.obj()] = key
def on_set(target, value, oldvalue, initiator):
+ """Listen for set/replace events on the target
+ data member.
+
+ Establish a weak reference to the parent object
+ on the incoming value, remove it for the one
+ outgoing.
+
+ """
+
if not isinstance(value, cls):
value = cls(value)
- value._set_parent(target.obj(), key)
+ value._parents[target.obj()] = key
if isinstance(oldvalue, cls):
- oldvalue._remove_parent()
+ oldvalue._parents.pop(state.obj(), None)
return value
event.listen(parent_cls, 'on_load', on_load, raw=True)
event.listen(parent_cls, 'on_refresh', on_load, raw=True)
event.listen(attribute, 'on_set', on_set, raw=True, retval=True)
-
+
+ @classmethod
+ def associate_with_type(cls, type_):
+ """Associate this wrapper with all future mapped columns
+ of the given type.
+
+ This is a convenience method that calls ``associate_with_attribute`` automatically.
+
+ """
+
+ def listen_for_type(mapper, class_):
+ for prop in mapper.iterate_properties:
+ if hasattr(prop, 'columns') and isinstance(prop.columns[0].type, type_):
+ cls.listen(getattr(class_, prop.key))
+
+ event.listen(mapper, 'on_mapper_configured', listen_for_type)
+
+
if __name__ == '__main__':
from sqlalchemy import Column, Integer, VARCHAR, create_engine
from sqlalchemy.orm import Session
if value is not None:
value = simplejson.loads(value, use_decimal=True)
return value
-
+
class MutationDict(TrackMutationsMixin, dict):
def __init__(self, other):
self.update(other)
def __delitem__(self, key):
dict.__delitem__(self, key)
self.on_change()
-
+
+ MutationDict.associate_with_type(JSONEncodedDict)
+
Base = declarative_base()
class Foo(Base):
__tablename__ = 'foo'
id = Column(Integer, primary_key=True)
data = Column(JSONEncodedDict)
-
- MutationDict.listen(Foo.data)
-
+
e = create_engine('sqlite://', echo=True)
Base.metadata.create_all(e)
event.Events.listen(target, identifier, fn)
def on_instrument_class(self, mapper, class_):
- """Receive a class when the mapper is first constructed, and has
- applied instrumentation to the mapped class.
+ """Receive a class when the mapper is first constructed,
+ before instrumentation is applied to the mapped class.
+
+ This event is the earliest phase of mapper construction.
+ Most attributes of the mapper are not yet initialized.
This listener can generally only be applied to the :class:`.Mapper`
class overall.
:param class\_: the mapped class.
"""
+
+ def on_mapper_configured(self, mapper, class_):
+ """Called when the mapper for the class is fully configured.
+ This event is the latest phase of mapper construction.
+ The mapper should be in its final state.
+
+ :param mapper: the :class:`.Mapper` which is the target
+ of this event.
+ :param class\_: the mapped class.
+
+ """
+ # TODO: need coverage for this event
+
def on_translate_row(self, mapper, context, row):
"""Perform pre-processing on the given result row and return a
new row instance.
or "append" event.
"""
+
+ @classmethod
+ def accept_with(cls, target):
+ from sqlalchemy.orm import interfaces
+ # TODO: coverage
+ if isinstance(target, interfaces.MapperProperty):
+ return getattr(target.parent.class_, target.key)
+ else:
+ return target
@classmethod
def listen(cls, target, identifier, fn, active_history=False,