]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- convert built in AttributExtensions to event listener fns
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 20 Dec 2010 22:50:57 +0000 (17:50 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 20 Dec 2010 22:50:57 +0000 (17:50 -0500)
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/orm/util.py

index 63331e081363a7f7daf2124fe1ea70db15733dbe..259f6e7e74b65c8ad359ff14b5ec8fb5a0692d34 100644 (file)
@@ -31,7 +31,6 @@ from sqlalchemy.orm.interfaces import (
      )
 from sqlalchemy.orm.util import (
      AliasedClass as aliased,
-     Validator,
      join,
      object_mapper,
      outerjoin,
@@ -63,7 +62,6 @@ __all__ = (
     'InstrumentationManager',
     'MapperExtension',
     'AttributeExtension',
-    'Validator',
     'PropComparator',
     'Query',
     'Session',
index 667e69190e163d2518459ec0a6db78faa344ec71..ffabf0bdc96e80f3b80f5b9d5c443e0e6b15aa0d 100644 (file)
@@ -1147,7 +1147,7 @@ def register_attribute_impl(class_, key,
         backref_listeners(manager[key], backref, uselist)
 
     manager.post_configure_attribute(key)
-
+    return manager[key]
     
 def register_descriptor(class_, key, comparator=None, 
                                 parententity=None, property_=None, doc=None):
index 1f704f5023767fc301e65e68619fd232de8a148c..88a8a8ea661e8d6ab85ad4f364420ce33840e2d3 100644 (file)
@@ -1605,8 +1605,6 @@ class Session(object):
 
 _expire_state = state.InstanceState.expire_attributes
     
-UOWEventHandler = unitofwork.UOWEventHandler
-
 _sessions = weakref.WeakValueDictionary()
 
 def make_transient(instance):
index 21f22ef5092c855d147818c642f8636c3318e7e2..ec2c3c2bc3b5109be35bf7a1a8144efa5bf76f15 100644 (file)
@@ -8,7 +8,7 @@
    implementations, and related MapperOptions."""
 
 from sqlalchemy import exc as sa_exc
-from sqlalchemy import sql, util, log
+from sqlalchemy import sql, util, log, event
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.sql import visitors, expression, operators
 from sqlalchemy.orm import mapper, attributes, interfaces, exc as orm_exc
@@ -17,7 +17,7 @@ from sqlalchemy.orm.interfaces import (
     LoaderStrategy, StrategizedOption, MapperOption, PropertyOption,
     serialize_path, deserialize_path, StrategizedProperty
     )
-from sqlalchemy.orm import session as sessionlib
+from sqlalchemy.orm import session as sessionlib, unitofwork
 from sqlalchemy.orm import util as mapperutil
 from sqlalchemy.orm.query import Query
 import itertools
@@ -38,22 +38,36 @@ def _register_attribute(strategy, mapper, useobject,
     prop = strategy.parent_property
 
     attribute_ext = list(util.to_list(prop.extension, default=[]))
-        
+    
+    listen_hooks = []
+    
     if useobject and prop.single_parent:
-        attribute_ext.insert(0, _SingleParentValidator(prop))
+        listen_hooks.append(single_parent_validator)
 
     if prop.key in prop.parent._validators:
-        attribute_ext.insert(0, 
-            mapperutil.Validator(prop.key, prop.parent._validators[prop.key])
+        listen_hooks.append(
+            lambda desc, prop: mapperutil._validator_events(desc, 
+                                prop.key, 
+                                prop.parent._validators[prop.key])
         )
     
     if useobject:
-        attribute_ext.append(sessionlib.UOWEventHandler(prop.key))
+        listen_hooks.append(unitofwork.track_cascade_events)
     
+    # need to assemble backref listeners
+    # after the singleparentvalidator, mapper validator
+    backref = kw.pop('backref', None)
+    if backref:
+        listen_hooks.append(
+            lambda desc, prop: attributes.backref_listeners(desc, 
+                                backref, 
+                                uselist)
+        )
+        
     for m in mapper.self_and_descendants:
         if prop is m._props.get(prop.key):
             
-            attributes.register_attribute_impl(
+            desc = attributes.register_attribute_impl(
                 m.class_, 
                 prop.key, 
                 parent_token=prop,
@@ -71,6 +85,9 @@ def _register_attribute(strategy, mapper, useobject,
                 doc=prop.doc,
                 **kw
                 )
+            
+            for hook in listen_hooks:
+                hook(desc, prop)
 
 class UninstrumentedColumnLoader(LoaderStrategy):
     """Represent the a non-instrumented MapperProperty.
@@ -1237,11 +1254,8 @@ class LoadEagerFromAliasOption(PropertyOption):
                         ("user_defined_eager_row_processor", 
                         interfaces._reduce_path(paths[-1]))] = adapter
 
-class _SingleParentValidator(interfaces.AttributeExtension):
-    def __init__(self, prop):
-        self.prop = prop
-
-    def _do_check(self, state, value, oldvalue, initiator):
+def single_parent_validator(desc, prop):
+    def _do_check(state, value, oldvalue, initiator):
         if value is not None:
             hasparent = initiator.hasparent(attributes.instance_state(value))
             if hasparent and oldvalue is not value: 
@@ -1249,14 +1263,16 @@ class _SingleParentValidator(interfaces.AttributeExtension):
                     "Instance %s is already associated with an instance "
                     "of %s via its %s attribute, and is only allowed a "
                     "single parent." % 
-                    (mapperutil.instance_str(value), state.class_, self.prop)
+                    (mapperutil.instance_str(value), state.class_, prop)
                 )
         return value
         
-    def append(self, state, value, initiator):
-        return self._do_check(state, value, None, initiator)
-
-    def set(self, state, value, oldvalue, initiator):
-        return self._do_check(state, value, oldvalue, initiator)
-
+    def append(state, value, initiator):
+        return _do_check(state, value, None, initiator)
 
+    def set_(state, value, oldvalue, initiator):
+        return _do_check(state, value, oldvalue, initiator)
+    
+    event.listen(desc, 'on_append', append, raw=True, retval=True, active_history=True)
+    event.listen(desc, 'on_set', set_, raw=True, retval=True, active_history=True)
+    
index ba43b13592a9d935936e6ed8cd9c310fe43a80d6..0dd5640a888002c35e6805610f4bfbce1a8529e4 100644 (file)
@@ -12,42 +12,37 @@ organizes them in order of dependency, and executes.
 
 """
 
-from sqlalchemy import util
+from sqlalchemy import util, event
 from sqlalchemy.util import topological
 from sqlalchemy.orm import attributes, interfaces
 from sqlalchemy.orm import util as mapperutil
 session = util.importlater("sqlalchemy.orm", "session")
 
-class UOWEventHandler(interfaces.AttributeExtension):
-    """An event handler added to all relationship attributes which handles
-    session cascade operations.
-    """
-    
-    active_history = False
+def track_cascade_events(descriptor, prop):
+    """Establish event listeners on object attributes which handle
+    cascade-on-set/append.
     
-    def __init__(self, key):
-        self.key = key
-        
-    # TODO: migrate these to unwrapped events
+    """
+    key = prop.key
     
-    def append(self, state, item, initiator):
+    def append(state, item, initiator):
         # process "save_update" cascade rules for when 
         # an instance is appended to the list of another instance
 
         sess = session._state_session(state)
         if sess:
-            prop = state.manager.mapper._props[self.key]
+            prop = state.manager.mapper._props[key]
             item_state = attributes.instance_state(item)
             if prop.cascade.save_update and \
-                (prop.cascade_backrefs or self.key == initiator.key) and \
+                (prop.cascade_backrefs or key == initiator.key) and \
                 not sess._contains_state(item_state):
                 sess._save_or_update_state(item_state)
         return item
         
-    def remove(self, state, item, initiator):
+    def remove(state, item, initiator):
         sess = session._state_session(state)
         if sess:
-            prop = state.manager.mapper._props[self.key]
+            prop = state.manager.mapper._props[key]
             # expunge pending orphans
             item_state = attributes.instance_state(item)
             if prop.cascade.delete_orphan and \
@@ -55,7 +50,7 @@ class UOWEventHandler(interfaces.AttributeExtension):
                 prop.mapper._is_orphan(item_state):
                     sess.expunge(item)
 
-    def set(self, state, newvalue, oldvalue, initiator):
+    def set_(state, newvalue, oldvalue, initiator):
         # process "save_update" cascade rules for when an instance 
         # is attached to another instance
         if oldvalue is newvalue:
@@ -63,11 +58,11 @@ class UOWEventHandler(interfaces.AttributeExtension):
 
         sess = session._state_session(state)
         if sess:
-            prop = state.manager.mapper._props[self.key]
+            prop = state.manager.mapper._props[key]
             if newvalue is not None:
                 newvalue_state = attributes.instance_state(newvalue)
                 if prop.cascade.save_update and \
-                    (prop.cascade_backrefs or self.key == initiator.key) and \
+                    (prop.cascade_backrefs or key == initiator.key) and \
                     not sess._contains_state(newvalue_state):
                     sess._save_or_update_state(newvalue_state)
             
@@ -78,6 +73,10 @@ class UOWEventHandler(interfaces.AttributeExtension):
                     prop.mapper._is_orphan(oldvalue_state):
                     sess.expunge(oldvalue)
         return newvalue
+        
+    event.listen(descriptor, 'on_append', append, raw=True, retval=True)
+    event.listen(descriptor, 'on_remove', remove, raw=True, retval=True)
+    event.listen(descriptor, 'on_set', set_, raw=True, retval=True)
 
 
 class UOWTransaction(object):
index b5fa0c0cff01a2eebea8b4a2b0c02d7484c23cc7..52e250239bc9e1dff6ed451bdef3564214356966 100644 (file)
@@ -6,11 +6,10 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 import sqlalchemy.exceptions as sa_exc
-from sqlalchemy import sql, util
+from sqlalchemy import sql, util, event
 from sqlalchemy.sql import expression, util as sql_util, operators
 from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE,\
-                                PropComparator, MapperProperty,\
-                                AttributeExtension
+                                PropComparator, MapperProperty
 from sqlalchemy.orm import attributes, exc
 import operator
 
@@ -55,37 +54,18 @@ class CascadeOptions(dict):
                          'delete_orphan', 'refresh-expire']
              if getattr(self, x, False) is True]))
 
+def _validator_events(desc, key, validator):
+    """Runs a validation method on an attribute value to be set or appended."""
 
-class Validator(AttributeExtension):
-    """Runs a validation method on an attribute value to be set or appended.
-
-    The Validator class is used by the :func:`~sqlalchemy.orm.validates`
-    decorator, and direct access is usually not needed.
-
-    """
-
-    def __init__(self, key, validator):
-        """Construct a new Validator.
-
-            key - name of the attribute to be validated;
-            will be passed as the second argument to
-            the validation method (the first is the object instance itself).
-
-            validator - an function or instance method which accepts
-            three arguments; an instance (usually just 'self' for a method),
-            the key name of the attribute, and the value.  The function should
-            return the same value given, unless it wishes to modify it.
-
-        """
-        self.key = key
-        self.validator = validator
-
-    def append(self, state, value, initiator):
-        return self.validator(state.obj(), self.key, value)
-
-    def set(self, state, value, oldvalue, initiator):
-        return self.validator(state.obj(), self.key, value)
+    def append(state, value, initiator):
+        return validator(state.obj(), key, value)
 
+    def set_(state, value, oldvalue, initiator):
+        return validator(state.obj(), key, value)
+    
+    event.listen(desc, 'on_append', append, raw=True, retval=True)
+    event.listen(desc, 'on_set', set_, raw=True, retval=True)
+    
 def polymorphic_union(table_map, typecolname, aliasname='p_union'):
     """Create a ``UNION`` statement used by a polymorphic mapper.