]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
merged current entity_management brach r3457-r3462. cleans up
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 5 Sep 2007 17:25:32 +0000 (17:25 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 5 Sep 2007 17:25:32 +0000 (17:25 +0000)
'_state' mamangement in attributes, moves __init__() instrumntation into attributes.py,
and reduces method call overhead by removing '_state' property.
future enhancements may include _state maintaining a weakref to the instance and a
strong ref to its __dict__ so that garbage-collected instances can get added to 'dirty',
when weak-referenced identity map is used.

lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/dynamic.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/session.py
test/orm/attributes.py
test/orm/collection.py
test/perf/masseagerload.py

index f369c53963e1586a87c4fdbd657a6855c42304fe..7290e2ac23e5aeee91a52cbb5e51a3f07b2a7eac 100644 (file)
@@ -84,13 +84,13 @@ class InstrumentedAttribute(interfaces.PropComparator):
         return self.get(obj)
 
     def commit_to_state(self, state, obj, value=NO_VALUE):
-        """commit the a copy of thte value of 'obj' to the given CommittedState"""
-
+        """commit the object's current state to its 'committed' state."""
+        
         if value is NO_VALUE:
             if self.key in obj.__dict__:
                 value = obj.__dict__[self.key]
         if value is not NO_VALUE:
-            state.data[self.key] = self.copy(value)
+            state.committed_state[self.key] = self.copy(value)
 
     def clause_element(self):
         return self.comparator.clause_element()
@@ -119,7 +119,7 @@ class InstrumentedAttribute(interfaces.PropComparator):
         will also not have a `hasparent` flag.
         """
 
-        return item._state.get(('hasparent', id(self)), optimistic)
+        return item._state.parents.get(id(self), optimistic)
 
     def sethasparent(self, item, value):
         """Set a boolean flag on the given item corresponding to
@@ -127,7 +127,7 @@ class InstrumentedAttribute(interfaces.PropComparator):
         attribute represented by this ``InstrumentedAttribute``.
         """
 
-        item._state[('hasparent', id(self))] = value
+        item._state.parents[id(self)] = value
 
     def get_history(self, obj, passive=False):
         """Return a new ``AttributeHistory`` object for the given object/this attribute's key.
@@ -165,11 +165,11 @@ class InstrumentedAttribute(interfaces.PropComparator):
         if callable_ is None:
             self.initialize(obj)
         else:
-            obj._state[('callable', self)] = callable_
+            obj._state.callables[self] = callable_
 
     def _get_callable(self, obj):
-        if ('callable', self) in obj._state:
-            return obj._state[('callable', self)]
+        if self in obj._state.callables:
+            return obj._state.callables[self]
         elif self.callable_ is not None:
             return self.callable_(obj)
         else:
@@ -183,7 +183,7 @@ class InstrumentedAttribute(interfaces.PropComparator):
         """
 
         try:
-            del obj._state[('callable', self)]
+            del obj._state.callables[self]
         except KeyError:
             pass
         self.clear(obj)
@@ -223,10 +223,8 @@ class InstrumentedAttribute(interfaces.PropComparator):
             state = obj._state
             # if an instance-wide "trigger" was set, call that
             # and start again
-            if 'trigger' in state:
-                trig = state['trigger']
-                del state['trigger']
-                trig()
+            if state.trigger:
+                state.call_trigger()
                 return self.get(obj, passive=passive)
 
             callable_ = self._get_callable(obj)
@@ -265,11 +263,10 @@ class InstrumentedAttribute(interfaces.PropComparator):
         """
 
         state = obj._state
-        orig = state.get('original', None)
-        if orig is not None:
-            self.commit_to_state(orig, obj, value)
+        if state.committed_state is not None:
+            self.commit_to_state(state, obj, value)
         # remove per-instance callable, if any
-        state.pop(('callable', self), None)
+        state.callables.pop(self, None)
         obj.__dict__[self.key] = value
         return value
 
@@ -278,21 +275,21 @@ class InstrumentedAttribute(interfaces.PropComparator):
         return value
 
     def fire_append_event(self, obj, value, initiator):
-        obj._state['modified'] = True
+        obj._state.modified = True
         if self.trackparent and value is not None:
             self.sethasparent(value, True)
         for ext in self.extensions:
             ext.append(obj, value, initiator or self)
 
     def fire_remove_event(self, obj, value, initiator):
-        obj._state['modified'] = True
+        obj._state.modified = True
         if self.trackparent and value is not None:
             self.sethasparent(value, False)
         for ext in self.extensions:
             ext.remove(obj, value, initiator or self)
 
     def fire_replace_event(self, obj, value, previous, initiator):
-        obj._state['modified'] = True
+        obj._state.modified = True
         if self.trackparent:
             if value is not None:
                 self.sethasparent(value, True)
@@ -334,7 +331,7 @@ class InstrumentedScalarAttribute(InstrumentedAttribute):
         if self.mutable_scalars:
             h = self.get_history(obj, passive=True)
             if h is not None and h.is_modified():
-                obj._state['modified'] = True
+                obj._state.modified = True
                 return True
             else:
                 return False
@@ -354,10 +351,8 @@ class InstrumentedScalarAttribute(InstrumentedAttribute):
 
         state = obj._state
         # if an instance-wide "trigger" was set, call that
-        if 'trigger' in state:
-            trig = state['trigger']
-            del state['trigger']
-            trig()
+        if state.trigger:
+            state.call_trigger()
 
         old = self.get(obj)
         obj.__dict__[self.key] = value
@@ -415,7 +410,7 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute):
         if self.key not in obj.__dict__:
             return
 
-        obj._state['modified'] = True
+        obj._state.modified = True
 
         collection = self.get_collection(obj)
         collection.clear_with_event()
@@ -453,10 +448,8 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute):
 
         state = obj._state
         # if an instance-wide "trigger" was set, call that
-        if 'trigger' in state:
-            trig = state['trigger']
-            del state['trigger']
-            trig()
+        if state.trigger:
+            state.call_trigger()
 
         old = self.get(obj)
         old_collection = self.get_collection(obj, old)
@@ -466,7 +459,7 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute):
                               collection=new_collection)
 
         obj.__dict__[self.key] = user_data
-        state['modified'] = True
+        state.modified = True
 
         # mark all the old elements as detached from the parent
         if old_collection:
@@ -477,17 +470,16 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute):
         """Set an attribute value on the given instance and 'commit' it."""
         
         state = obj._state
-        orig = state.get('original', None)
 
         collection, user_data = self._build_collection(obj)
         self._load_collection(obj, value or [], emit_events=False,
                               collection=collection)
         value = user_data
 
-        if orig is not None:
-            self.commit_to_state(orig, obj, value)
+        if state.committed_state is not None:
+            self.commit_to_state(state, obj, value)
         # remove per-instance callable, if any
-        state.pop(('callable', self), None)
+        state.callables.pop(self, None)
         obj.__dict__[self.key] = value
         return value
 
@@ -543,38 +535,57 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
     def remove(self, obj, child, initiator):
         getattr(child.__class__, self.key).remove(child, obj, initiator)
 
-class CommittedState(object):
-    """Store the original state of an object when the ``commit()`
-    method on the attribute manager is called.
-    """
-
-
-    def __init__(self, manager, obj):
-        self.data = {}
+class InstanceState(object):
+    """tracks state information at the instance level."""
+    
+    def __init__(self, obj):
+        self.committed_state = None
+        self.modified = False
+        self.trigger = None
+        self.callables = {}
+        self.parents = {}
+    
+    def __getstate__(self):
+        return {'committed_state':self.committed_state, 'parents':self.parents, 'modified':self.modified}
+    
+    def __setstate__(self, state):
+        self.committed_state = state['committed_state']
+        self.parents = state['parents']
+        self.modified = state['modified']
+        self.callables = {}
+        self.trigger = None
+        
+    def call_trigger(self):
+        trig = self.trigger
+        self.trigger = None
+        trig()
+        
+    def commit(self, manager, obj):
+        self.committed_state = {}
+        self.modified = False
         for attr in manager.managed_attributes(obj.__class__):
             attr.commit_to_state(self, obj)
 
     def rollback(self, manager, obj):
-        for attr in manager.managed_attributes(obj.__class__):
-            if attr.key in self.data:
-                if not hasattr(attr, 'get_collection'):
-                    obj.__dict__[attr.key] = self.data[attr.key]
+        if not self.committed_state:
+            manager._clear(obj)
+        else:
+            for attr in manager.managed_attributes(obj.__class__):
+                if attr.key in self.committed_state:
+                    if not hasattr(attr, 'get_collection'):
+                        obj.__dict__[attr.key] = self.committed_state[attr.key]
+                    else:
+                        collection = attr.get_collection(obj)
+                        collection.clear_without_event()
+                        for item in self.committed_state[attr.key]:
+                            collection.append_without_event(item)
                 else:
-                    collection = attr.get_collection(obj)
-                    collection.clear_without_event()
-                    for item in self.data[attr.key]:
-                        collection.append_without_event(item)
-            else:
-                if attr.key in obj.__dict__:
-                    del obj.__dict__[attr.key]
-
-    def __repr__(self):
-        return "CommittedState: %s" % repr(self.data)
+                    if attr.key in obj.__dict__:
+                        del obj.__dict__[attr.key]
 
 class AttributeHistory(object):
     """Calculate the *history* of a particular attribute on a
-    particular instance, based on the ``CommittedState`` associated
-    with the instance, if any.
+    particular instance.
     """
 
     def __init__(self, attr, obj, current, passive=False):
@@ -583,9 +594,8 @@ class AttributeHistory(object):
         # get the "original" value.  if a lazy load was fired when we got
         # the 'current' value, this "original" was also populated just
         # now as well (therefore we have to get it second)
-        orig = obj._state.get('original', None)
-        if orig is not None:
-            original = orig.data.get(attr.key)
+        if obj._state.committed_state:
+            original = obj._state.committed_state.get(attr.key, None)
         else:
             original = None
 
@@ -652,11 +662,7 @@ class AttributeManager(object):
         """
 
         for o in obj:
-            orig = o._state.get('original')
-            if orig is not None:
-                orig.rollback(self, o)
-            else:
-                self._clear(o)
+            o._state.rollback(self, o)
 
     def _clear(self, obj):
         for attr in self.managed_attributes(obj.__class__):
@@ -664,19 +670,12 @@ class AttributeManager(object):
                 del obj.__dict__[attr.key]
             except KeyError:
                 pass
-
+    
     def commit(self, *obj):
-        """Create a ``CommittedState`` instance for each object in the given list, representing
-        its *unchanged* state, and associates it with the instance.
-
-        ``AttributeHistory`` objects will indicate the modified state of
-        instance attributes as compared to its value in this
-        ``CommittedState`` object.
-        """
+        """Establish the "committed state" for each object in the given list."""
 
         for o in obj:
-            o._state['original'] = CommittedState(self, o)
-            o._state['modified'] = False
+            o._state.commit(self, o)
 
     def managed_attributes(self, class_):
         """Return a list of all ``InstrumentedAttribute`` objects
@@ -706,7 +705,7 @@ class AttributeManager(object):
         for attr in self.managed_attributes(object.__class__):
             if attr.check_mutable_modified(object):
                 return True
-        return object._state.get('modified', False)
+        return object._state.modified
 
     def get_history(self, obj, key, **kwargs):
         """Return a new ``AttributeHistory`` object for the given
@@ -743,12 +742,10 @@ class AttributeManager(object):
         removed.
         """
 
+        s = obj._state
         self._clear(obj)
-        try:
-            del obj._state['original']
-        except KeyError:
-            pass
-        obj._state['trigger'] = callable
+        s.committed_state = None
+        s.trigger = callable
 
     def untrigger_history(self, obj):
         """Remove a trigger function set by trigger_history.
@@ -756,14 +753,14 @@ class AttributeManager(object):
         Does not restore the previous state of the object.
         """
 
-        del obj._state['trigger']
+        obj._state.trigger = None
 
     def has_trigger(self, obj):
         """Return True if the given object has a trigger function set
         by ``trigger_history()``.
         """
 
-        return 'trigger' in obj._state
+        return obj._state.trigger is not None
 
     def reset_instance_attribute(self, obj, key):
         """Remove any per-instance callable functions corresponding to
@@ -774,16 +771,6 @@ class AttributeManager(object):
         attr = getattr(obj.__class__, key)
         attr.reset(obj)
 
-    def reset_class_managed(self, class_):
-        """Remove all ``InstrumentedAttribute`` property objects from
-        the given class.
-        """
-
-        for attr in self.noninherited_managed_attributes(class_):
-            delattr(class_, attr.key)
-        self._inherited_attribute_cache.pop(class_,None)
-        self._noninherited_attribute_cache.pop(class_,None)
-
     def is_class_managed(self, class_, key):
         """Return True if the given `key` correponds to an
         instrumented property on the given class.
@@ -826,7 +813,71 @@ class AttributeManager(object):
             return getattr(obj_or_cls, key)
         else:
             return getattr(obj_or_cls.__class__, key)
+    
+    def manage(self, obj):
+        if not hasattr(obj, '_state'):
+            obj._state = InstanceState(obj)
+            
+    def new_instance(self, class_):
+        """create a new instance of class_ without its __init__() method being called."""
+        
+        s = class_.__new__(class_)
+        s._state = InstanceState(s)
+        return s
+        
+    def register_class(self, class_, extra_init=None, on_exception=None):
+        """decorate the constructor of the given class to establish attribute
+        management on new instances."""
 
+        oldinit = None
+        doinit = False
+            
+        def init(instance, *args, **kwargs):
+            instance._state = InstanceState(instance)
+
+            if extra_init:
+                extra_init(class_, oldinit, instance, args, kwargs)
+
+            if doinit:
+                try:
+                    oldinit(instance, *args, **kwargs)
+                except:
+                    if on_exception:
+                        on_exception(class_, oldinit, instance, args, kwargs)
+                    raise
+        
+        # override oldinit
+        oldinit = class_.__init__
+        if oldinit is None or not hasattr(oldinit, '_oldinit'):
+            init._oldinit = oldinit
+            class_.__init__ = init
+        # if oldinit is already one of our 'init' methods, replace it
+        elif hasattr(oldinit, '_oldinit'):
+            init._oldinit = oldinit._oldinit
+            class_.__init = init
+            oldinit = oldinit._oldinit
+            
+        if oldinit is not None:
+            doinit = oldinit is not object.__init__
+            try:
+                init.__name__ = oldinit.__name__
+                init.__doc__ = oldinit.__doc__
+            except:
+                # cant set __name__ in py 2.3 !
+                pass
+            
+    def unregister_class(self, class_):
+        if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'):
+            if class_.__init__._oldinit is not None:
+                class_.__init__ = class_.__init__._oldinit
+            else:
+                delattr(class_, '__init__')
+                
+        for attr in self.noninherited_managed_attributes(class_):
+            delattr(class_, attr.key)
+        self._inherited_attribute_cache.pop(class_,None)
+        self._noninherited_attribute_cache.pop(class_,None)
+        
     def register_attribute(self, class_, key, uselist, callable_=None, **kwargs):
         """Register an attribute at the class level to be instrumented
         for all instances of the class.
@@ -837,13 +888,6 @@ class AttributeManager(object):
         self._inherited_attribute_cache.pop(class_, None)
         self._noninherited_attribute_cache.pop(class_, None)
 
-        if not hasattr(class_, '_state'):
-            def _get_state(self):
-                if not hasattr(self, '_sa_attr_state'):
-                    self._sa_attr_state = {}
-                return self._sa_attr_state
-            class_._state = property(_get_state)
-
         typecallable = kwargs.pop('typecallable', None)
         if isinstance(typecallable, InstrumentedAttribute):
             typecallable = None
index 1d4b5f6c939ac70f9cc2567e7e44813ca8df07bd..aa510515031a764a2207dbb05529c1ed58144c71 100644 (file)
@@ -34,7 +34,7 @@ class DynamicCollectionAttribute(attributes.InstrumentedAttribute):
         old_collection = self.get(obj).assign(value)
 
         # TODO: emit events ???
-        state['modified'] = True
+        state.modified = True
 
     def delete(self, *args, **kwargs):
         raise NotImplementedError()
index 960282255500a4e870c1bfa78650d96c889169c0..5d495d7a97683c66117e0c4859c3026764819065 100644 (file)
@@ -198,14 +198,9 @@ class Mapper(object):
     def dispose(self):
         # disaable any attribute-based compilation
         self.__props_init = True
-        attribute_manager.reset_class_managed(self.class_)
         if hasattr(self.class_, 'c'):
             del self.class_.c
-        if hasattr(self.class_, '__init__') and hasattr(self.class_.__init__, '_oldinit'):
-            if self.class_.__init__._oldinit is not None:
-                self.class_.__init__ = self.class_.__init__._oldinit
-            else:
-                delattr(self.class_, '__init__')
+        attribute_manager.unregister_class(self.class_)
         
     def compile(self):
         """Compile this mapper into its final internal format.
@@ -664,34 +659,14 @@ class Mapper(object):
         if not self.non_primary and (self.class_key in mapper_registry):
              raise exceptions.ArgumentError("Class '%s' already has a primary mapper defined with entity name '%s'.  Use non_primary=True to create a non primary Mapper, or to create a new primary mapper, remove this mapper first via sqlalchemy.orm.clear_mapper(mapper), or preferably sqlalchemy.orm.clear_mappers() to clear all mappers." % (self.class_, self.entity_name))
 
-        attribute_manager.reset_class_managed(self.class_)
-
-        oldinit = self.class_.__init__
-        doinit = oldinit is not None and oldinit is not object.__init__
-            
-        def init(instance, *args, **kwargs):
+        def extra_init(class_, oldinit, instance, args, kwargs):
             self.compile()
-            self.extension.init_instance(self, self.class_, oldinit, instance, args, kwargs)
-
-            if doinit:
-                try:
-                    oldinit(instance, *args, **kwargs)
-                except:
-                    # call init_failed but suppress exceptions into warnings so that original __init__ 
-                    # exception is raised
-                    util.warn_exception(self.extension.init_failed, self, self.class_, oldinit, instance, args, kwargs)
-                    raise
-
-        # override oldinit, ensuring that its not already a Mapper-decorated init method
-        if oldinit is None or not hasattr(oldinit, '_oldinit'):
-            try:
-                init.__name__ = oldinit.__name__
-                init.__doc__ = oldinit.__doc__
-            except:
-                # cant set __name__ in py 2.3 !
-                pass
-            init._oldinit = oldinit
-            self.class_.__init__ = init
+            self.extension.init_instance(self, class_, oldinit, instance, args, kwargs)
+        
+        def on_exception(class_, oldinit, instance, args, kwargs):
+            util.warn_exception(self.extension.init_failed, self, class_, oldinit, instance, args, kwargs)
+
+        attribute_manager.register_class(self.class_, extra_init=extra_init, on_exception=on_exception)
 
         _COMPILE_MUTEX.acquire()
         try:
@@ -1436,7 +1411,7 @@ class Mapper(object):
             # plugin point
             instance = extension.create_instance(self, context, row, self.class_)
             if instance is EXT_CONTINUE:
-                instance = self._create_instance(context.session)
+                instance = attribute_manager.new_instance(self.class_)
             instance._entity_name = self.entity_name
             if self.__should_log_debug:
                 self.__log_debug("_instance(): created new instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey)))
@@ -1459,12 +1434,6 @@ class Mapper(object):
         
         return instance
 
-    def _create_instance(self, session):
-        obj = self.class_.__new__(self.class_)
-        obj._entity_name = self.entity_name
-
-        return obj
-
     def _deferred_inheritance_condition(self, needs_tables):
         cond = self.inherit_condition
 
index ebd4bd3d312e46d1575d3518e54a730c912c2868..b616570ab39ca2575751281a3b3e252b53e0a458 100644 (file)
@@ -847,7 +847,7 @@ class Session(object):
         try:
             key = getattr(object, '_instance_key', None)
             if key is None:
-                merged = mapper.class_.__new__(mapper.class_)
+                merged = attribute_manager.new_instance(mapper.class_)
             else:
                 if key in self.identity_map:
                     merged = self.identity_map[key]
@@ -940,16 +940,9 @@ class Session(object):
                                                      "or is already persistent in a "
                                                      "different Session" % repr(obj))
         else:
-            m = _class_mapper(obj.__class__, entity_name=kwargs.get('entity_name', None))
-
-            # this would be a nice exception to raise...however this is incompatible with a contextual
-            # session which puts all objects into the session upon construction.
-            #if m._is_orphan(object):
-            #    raise exceptions.InvalidRequestError("Instance '%s' is an orphan, "
-            #                                         "and must be attached to a parent "
-            #                                         "object to be saved" % (repr(object)))
-
-            m._assign_entity_name(obj)
+            # TODO: consolidate the steps here
+            attribute_manager.manage(obj)
+            obj._entity_name = kwargs.get('entity_name', None)
             self._attach(obj)
             self.uow.register_new(obj)
 
index 8ca2d1b8e1a8ce1abfb9b3bb5227668ff9623df3..6314656b92325954931e3565fcffec8c13d33d62 100644 (file)
@@ -5,14 +5,17 @@ from sqlalchemy.orm.collections import collection
 from sqlalchemy import exceptions
 from testlib import *
 
+# these test classes defined at the module
+# level to support pickling
 class MyTest(object):pass
 class MyTest2(object):pass
-    
+
 class AttributesTest(PersistTest):
     """tests for the attributes.py module, which deals with tracking attribute changes on an object."""
-    def testbasic(self):
+    def test_basic(self):
         class User(object):pass
         manager = attributes.AttributeManager()
+        manager.register_class(User)
         manager.register_attribute(User, 'user_id', uselist = False)
         manager.register_attribute(User, 'user_name', uselist = False)
         manager.register_attribute(User, 'email_address', uselist = False)
@@ -39,8 +42,11 @@ class AttributesTest(PersistTest):
         print repr(u.__dict__)
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
 
-    def testpickleness(self):
+    def test_pickleness(self):
+
         manager = attributes.AttributeManager()
+        manager.register_class(MyTest)
+        manager.register_class(MyTest2)
         manager.register_attribute(MyTest, 'user_id', uselist = False)
         manager.register_attribute(MyTest, 'user_name', uselist = False)
         manager.register_attribute(MyTest, 'email_address', uselist = False)
@@ -97,10 +103,12 @@ class AttributesTest(PersistTest):
         self.assert_(o4.mt2[0].a == 'abcde')
         self.assert_(o4.mt2[0].b is None)
 
-    def testlist(self):
+    def test_list(self):
         class User(object):pass
         class Address(object):pass
         manager = attributes.AttributeManager()
+        manager.register_class(User)
+        manager.register_class(Address)
         manager.register_attribute(User, 'user_id', uselist = False)
         manager.register_attribute(User, 'user_name', uselist = False)
         manager.register_attribute(User, 'addresses', uselist = True)
@@ -138,10 +146,12 @@ class AttributesTest(PersistTest):
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
         self.assert_(len(manager.get_history(u, 'addresses').unchanged_items()) == 1)
 
-    def testbackref(self):
+    def test_backref(self):
         class Student(object):pass
         class Course(object):pass
         manager = attributes.AttributeManager()
+        manager.register_class(Student)
+        manager.register_class(Course)
         manager.register_attribute(Student, 'courses', uselist=True, extension=attributes.GenericBackrefExtension('students'))
         manager.register_attribute(Course, 'students', uselist=True, extension=attributes.GenericBackrefExtension('courses'))
         
@@ -166,7 +176,9 @@ class AttributesTest(PersistTest):
         self.assert_(c.students == [s2,s3])        
         class Post(object):pass
         class Blog(object):pass
-        
+
+        manager.register_class(Post)
+        manager.register_class(Blog)
         manager.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True)
         manager.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True)
         b = Blog()
@@ -190,6 +202,8 @@ class AttributesTest(PersistTest):
 
         class Port(object):pass
         class Jack(object):pass
+        manager.register_class(Port)
+        manager.register_class(Jack)
         manager.register_attribute(Port, 'jack', uselist=False, extension=attributes.GenericBackrefExtension('port'))
         manager.register_attribute(Jack, 'port', uselist=False, extension=attributes.GenericBackrefExtension('jack'))
         p = Port()
@@ -201,13 +215,15 @@ class AttributesTest(PersistTest):
         j.port = None
         self.assert_(p.jack is None)
 
-    def testlazytrackparent(self):
+    def test_lazytrackparent(self):
         """test that the "hasparent" flag works properly when lazy loaders and backrefs are used"""
         manager = attributes.AttributeManager()
 
         class Post(object):pass
         class Blog(object):pass
-
+        manager.register_class(Post)
+        manager.register_class(Blog)
+        
         # set up instrumented attributes with backrefs    
         manager.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True)
         manager.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True)
@@ -234,12 +250,14 @@ class AttributesTest(PersistTest):
         assert getattr(Blog, 'posts').hasparent(p2)
         assert getattr(Post, 'blog').hasparent(b2)
         
-    def testinheritance(self):
+    def test_inheritance(self):
         """tests that attributes are polymorphic"""
         class Foo(object):pass
         class Bar(Foo):pass
         
         manager = attributes.AttributeManager()
+        manager.register_class(Foo)
+        manager.register_class(Bar)
         
         def func1():
             print "func1"
@@ -261,12 +279,14 @@ class AttributesTest(PersistTest):
         assert x.element2 == 'this is the shared attr'
         assert y.element2 == 'this is the shared attr'
 
-    def testinheritance2(self):
+    def test_inheritance2(self):
         """test that the attribute manager can properly traverse the managed attributes of an object,
         if the object is of a descendant class with managed attributes in the parent class"""
         class Foo(object):pass
         class Bar(Foo):pass
         manager = attributes.AttributeManager()
+        manager.register_class(Foo)
+        manager.register_class(Bar)
         manager.register_attribute(Foo, 'element', uselist=False)
         x = Bar()
         x.element = 'this is the element'
@@ -277,7 +297,7 @@ class AttributesTest(PersistTest):
         assert hist.added_items() == []
         assert hist.unchanged_items() == ['this is the element']
 
-    def testlazyhistory(self):
+    def test_lazyhistory(self):
         """tests that history functions work with lazy-loading attributes"""
         class Foo(object):pass
         class Bar(object):
@@ -287,6 +307,8 @@ class AttributesTest(PersistTest):
                 return "Bar: id %d" % self.id
                 
         manager = attributes.AttributeManager()
+        manager.register_class(Foo)
+        manager.register_class(Bar)
 
         def func1():
             return "this is func 1"
@@ -305,11 +327,13 @@ class AttributesTest(PersistTest):
         print h.unchanged_items()
 
         
-    def testparenttrack(self):    
+    def test_parenttrack(self):    
         class Foo(object):pass
         class Bar(object):pass
         
         manager = attributes.AttributeManager()
+        manager.register_class(Foo)
+        manager.register_class(Bar)
         
         manager.register_attribute(Foo, 'element', uselist=False, trackparent=True)
         manager.register_attribute(Bar, 'element', uselist=False, trackparent=True)
@@ -330,10 +354,11 @@ class AttributesTest(PersistTest):
         b2.element = None
         assert not getattr(Bar, 'element').hasparent(f2)
 
-    def testmutablescalars(self):
+    def test_mutablescalars(self):
         """test detection of changes on mutable scalar items"""
         class Foo(object):pass
         manager = attributes.AttributeManager()
+        manager.register_class(Foo)
         manager.register_attribute(Foo, 'element', uselist=False, copy_function=lambda x:[y for y in x], mutable_scalars=True)
         x = Foo()
         x.element = ['one', 'two', 'three']    
@@ -341,8 +366,9 @@ class AttributesTest(PersistTest):
         x.element[1] = 'five'
         assert manager.is_modified(x)
         
-        manager.reset_class_managed(Foo)
+        manager.unregister_class(Foo)
         manager = attributes.AttributeManager()
+        manager.register_class(Foo)
         manager.register_attribute(Foo, 'element', uselist=False)
         x = Foo()
         x.element = ['one', 'two', 'three']    
@@ -350,7 +376,7 @@ class AttributesTest(PersistTest):
         x.element[1] = 'five'
         assert not manager.is_modified(x)
         
-    def testdescriptorattributes(self):
+    def test_descriptorattributes(self):
         """changeset: 1633 broke ability to use ORM to map classes with unusual
         descriptor attributes (for example, classes that inherit from ones
         implementing zope.interface.Interface).
@@ -363,11 +389,12 @@ class AttributesTest(PersistTest):
             A = des()
 
         manager = attributes.AttributeManager()
-        manager.reset_class_managed(Foo)
+        manager.unregister_class(Foo)
     
-    def testcollectionclasses(self):
+    def test_collectionclasses(self):
         manager = attributes.AttributeManager()
         class Foo(object):pass
+        manager.register_class(Foo)
         manager.register_attribute(Foo, "collection", uselist=True, typecallable=set)
         assert isinstance(Foo().collection, set)
         
index 0cc8cf7e06ae5802636da3b020bea1e6d8b78011..9d5ae7ab92281d35ee093e6c8eb7d5e0a4486324 100644 (file)
@@ -36,6 +36,7 @@ class Entity(object):
         return str((id(self), self.a, self.b, self.c))
 
 manager = attributes.AttributeManager()
+manager.register_class(Entity)
 
 _id = 1
 def entity_maker():
@@ -55,6 +56,7 @@ class CollectionsTest(PersistTest):
             pass
 
         canary = Canary()
+        manager.register_class(Foo)
         manager.register_attribute(Foo, 'attr', True, extension=canary,
                                    typecallable=typecallable)
 
@@ -92,6 +94,7 @@ class CollectionsTest(PersistTest):
             pass
         
         canary = Canary()
+        manager.register_class(Foo)
         manager.register_attribute(Foo, 'attr', True, extension=canary,
                                    typecallable=typecallable)
 
@@ -233,6 +236,7 @@ class CollectionsTest(PersistTest):
             pass
 
         canary = Canary()
+        manager.register_class(Foo)
         manager.register_attribute(Foo, 'attr', True, extension=canary,
                                    typecallable=typecallable)
 
@@ -341,6 +345,7 @@ class CollectionsTest(PersistTest):
             pass
 
         canary = Canary()
+        manager.register_class(Foo)
         manager.register_attribute(Foo, 'attr', True, extension=canary,
                                    typecallable=typecallable)
 
@@ -473,6 +478,7 @@ class CollectionsTest(PersistTest):
             pass
 
         canary = Canary()
+        manager.register_class(Foo)
         manager.register_attribute(Foo, 'attr', True, extension=canary,
                                    typecallable=typecallable)
 
@@ -577,6 +583,7 @@ class CollectionsTest(PersistTest):
             pass
 
         canary = Canary()
+        manager.register_class(Foo)
         manager.register_attribute(Foo, 'attr', True, extension=canary,
                                    typecallable=typecallable)
 
@@ -694,6 +701,7 @@ class CollectionsTest(PersistTest):
             pass
 
         canary = Canary()
+        manager.register_class(Foo)
         manager.register_attribute(Foo, 'attr', True, extension=canary,
                                    typecallable=typecallable)
 
@@ -868,6 +876,7 @@ class CollectionsTest(PersistTest):
             pass
         
         canary = Canary()
+        manager.register_class(Foo)
         manager.register_attribute(Foo, 'attr', True, extension=canary,
                                    typecallable=typecallable)
 
@@ -1001,6 +1010,7 @@ class CollectionsTest(PersistTest):
         class Foo(object):
             pass
         canary = Canary()
+        manager.register_class(Foo)
         manager.register_attribute(Foo, 'attr', True, extension=canary,
                                    typecallable=Custom)
 
@@ -1070,6 +1080,7 @@ class CollectionsTest(PersistTest):
 
         canary = Canary()
         creator = entity_maker
+        manager.register_class(Foo)
         manager.register_attribute(Foo, 'attr', True, extension=canary)
 
         obj = Foo()
index ad438c1faa195185c61a2eb8168f7736fc109155..38696e85b99bffba9a07c2e8456c10af000f0ad2 100644 (file)
@@ -35,7 +35,7 @@ def load():
         #print l
         subitems.insert().execute(*l)    
 
-@profiling.profiled('masseagerload', always=True)
+@profiling.profiled('masseagerload', always=True, sort=['cumulative'])
 def masseagerload(session):
     query = session.query(Item)
     l = query.select()