]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- pared down private and semi-private functions in the attributes package.
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 23 Feb 2009 00:08:37 +0000 (00:08 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 23 Feb 2009 00:08:37 +0000 (00:08 +0000)
- simplified the process of establishment and unestablishment of
class management from a mapper perspective; class manager setup/teardown
is now symmetric (ClassManager would never be fully de-associated previously).
- class manager now unconditionally decorates __init__.  this has a slight
behavior change for an unmapped subclass of a mapped superclass, in that
InstanceState creation corresponds to that of the superclass.  This
still doesn't allow unmapped subclasses to be usable in mapper
situations, though.

CHANGES
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/mapper.py
test/orm/extendedattr.py
test/orm/instrumentation.py

diff --git a/CHANGES b/CHANGES
index ddff28fb7c5860a28f9c743e12eb301d9c81f0fa..5848a679656d1f7d3d621e18c3e375cfe4f4beeb 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -61,7 +61,7 @@ CHANGES
      - Query won't fail with weakref error when a non-mapper/class
        instrumented descriptor is passed, raises 
        "Invalid column expession".
-       
+     
 - sql
     - Fixed missing _label attribute on Function object, others
       when used in a select() with use_labels (such as when used
index d95acf0c7ba68afa7276a8cacfc36948af5704ca..b39ef990ebb31366994dfee775237e2eb25986b6 100644 (file)
@@ -1141,12 +1141,14 @@ class ClassManager(dict):
 
     event_registry_factory = Events
     instance_state_factory = InstanceState
-
+    deferred_scalar_loader = None
+    
     def __init__(self, class_):
         self.class_ = class_
         self.factory = None  # where we came from, for inheritance bookkeeping
         self.info = {}
         self.mapper = None
+        self.new_init = None
         self.mutable_attributes = set()
         self.local_attrs = {}
         self.originals = {}
@@ -1156,28 +1158,73 @@ class ClassManager(dict):
             cls_state = manager_of_class(base)
             if cls_state:
                 self.update(cls_state)
-        self.registered = False
-        self._instantiable = False
         self.events = self.event_registry_factory()
-
-    def instantiable(self, boolean):
-        # experiment, probably won't stay in this form
-        assert boolean ^ self._instantiable, (boolean, self._instantiable)
-        if boolean:
-            self.events.original_init = self.class_.__init__
-            new_init = _generate_init(self.class_, self)
-            self.install_member('__init__', new_init)
-        else:
+        self.manage()
+        self._instrument_init()
+    
+    def _configure_create_arguments(self, 
+                            _source=None, 
+                            instance_state_factory=None, 
+                            deferred_scalar_loader=None):
+        """Accept extra **kw arguments passed to create_manager_for_cls.
+        
+        The current contract of ClassManager and other managers is that they
+        take a single "cls" argument in their constructor (as per 
+        test/orm/instrumentation.py InstrumentationCollisionTest).  This
+        is to provide consistency with the current API of "class manager"
+        callables and such which may return various ClassManager and 
+        ClassManager-like instances.   So create_manager_for_cls sends
+        in ClassManager-specific arguments via this method once the 
+        non-proxied ClassManager is available.
+        
+        """
+        if _source:
+            instance_state_factory = _source.instance_state_factory
+            deferred_scalar_loader = _source.deferred_scalar_loader
+
+        if instance_state_factory:
+            self.instance_state_factory = instance_state_factory
+        if deferred_scalar_loader:
+            self.deferred_scalar_loader = deferred_scalar_loader
+    
+    def _subclass_manager(self, cls):
+        """Create a new ClassManager for a subclass of this ClassManager's class.
+        
+        This is called automatically when attributes are instrumented so that
+        the attributes can be propagated to subclasses against their own
+        class-local manager, without the need for mappers etc. to have already
+        pre-configured managers for the full class hierarchy.   Mappers
+        can post-configure the auto-generated ClassManager when needed.
+        
+        """
+        manager = manager_of_class(cls)
+        if manager is None:
+            manager = _create_manager_for_cls(cls, _source=self)
+        return manager
+        
+    def _instrument_init(self):
+        # TODO: self.class_.__init__ is often the already-instrumented
+        # __init__ from an instrumented superclass.  We still need to make 
+        # our own wrapper, but it would
+        # be nice to wrap the original __init__ and not our existing wrapper
+        # of such, since this adds method overhead.
+        self.events.original_init = self.class_.__init__
+        self.new_init = _generate_init(self.class_, self)
+        self.install_member('__init__', self.new_init)
+        
+    def _uninstrument_init(self):
+        if self.new_init:
             self.uninstall_member('__init__')
-        self._instantiable = bool(boolean)
-    instantiable = property(lambda s: s._instantiable, instantiable)
+            self.new_init = None
 
     def manage(self):
         """Mark this instance as the manager for its class."""
+        
         setattr(self.class_, self.MANAGER_ATTR, self)
 
     def dispose(self):
-        """Dissasociate this instance from its class."""
+        """Dissasociate this manager from its class."""
+        
         delattr(self.class_, self.MANAGER_ATTR)
 
     def manager_getter(self):
@@ -1194,9 +1241,7 @@ class ClassManager(dict):
         for cls in self.class_.__subclasses__():
             if isinstance(cls, types.ClassType):
                 continue
-            manager = manager_of_class(cls)
-            if manager is None:
-                manager = create_manager_for_cls(cls)
+            manager = self._subclass_manager(cls)
             manager.instrument_attribute(key, inst, True)
 
     def post_configure_attribute(self, key):
@@ -1217,16 +1262,20 @@ class ClassManager(dict):
         for cls in self.class_.__subclasses__():
             if isinstance(cls, types.ClassType):
                 continue
-            manager = manager_of_class(cls)
-            if manager is None:
-                manager = create_manager_for_cls(cls)
+            manager = self._subclass_manager(cls)
             manager.uninstrument_attribute(key, True)
 
     def unregister(self):
+        """remove all instrumentation established by this ClassManager."""
+        
+        self._uninstrument_init()
+
+        self.mapper = self.events = None
+        self.info.clear()
+        
         for key in list(self):
             if key in self.local_attrs:
                 self.uninstrument_attribute(key)
-        self.registered = False
 
     def install_descriptor(self, key, inst):
         if key in (self.STATE_ATTR, self.MANAGER_ATTR):
@@ -1271,15 +1320,6 @@ class ClassManager(dict):
     def attributes(self):
         return self.itervalues()
 
-    @classmethod
-    def deferred_scalar_loader(cls, state, keys):
-        """Apply a scalar loader to the given state.
-        
-        Unimplemented by default, is patched
-        by the mapper.
-        
-        """
-
     ## InstanceState management
 
     def new_instance(self, state=None):
@@ -1337,9 +1377,9 @@ class ClassManager(dict):
 class _ClassInstrumentationAdapter(ClassManager):
     """Adapts a user-defined InstrumentationManager to a ClassManager."""
 
-    def __init__(self, class_, override):
-        ClassManager.__init__(self, class_)
+    def __init__(self, class_, override, **kw):
         self._adapted = override
+        ClassManager.__init__(self, class_, **kw)
 
     def manage(self):
         self._adapted.manage(self.class_, self)
@@ -1557,25 +1597,21 @@ def has_parent(cls, obj, key, optimistic=False):
     state = instance_state(obj)
     return manager.has_parent(state, key, optimistic)
 
-def register_class(class_):
-    """TODO"""
-
-    # TODO: what's this function for ?  why would I call this and not
-    # create_manager_for_cls ?
+def register_class(class_, **kw):
+    """Register class instrumentation.
+    
+    Returns the existing or newly created class manager.
+    """
 
     manager = manager_of_class(class_)
     if manager is None:
-        manager = create_manager_for_cls(class_)
-    if not manager.instantiable:
-        manager.instantiable = True
-
+        manager = _create_manager_for_cls(class_, **kw)
+    return manager
+    
 def unregister_class(class_):
-    """TODO"""
-    manager = manager_of_class(class_)
-    assert manager
-    assert manager.instantiable
-    manager.instantiable = False
-    manager.unregister()
+    """Unregister class instrumentation."""
+    
+    instrumentation_registry.unregister(class_)
 
 def register_attribute(class_, key, **kw):
 
@@ -1587,18 +1623,34 @@ def register_attribute(class_, key, **kw):
     if not proxy_property:
         register_attribute_impl(class_, key, **kw)
     
-def register_attribute_impl(class_, key, **kw):
+def register_attribute_impl(class_, key,         
+        uselist=False, callable_=None, 
+        useobject=False, mutable_scalars=False, 
+        impl_class=None, **kw):
     
     manager = manager_of_class(class_)
-    uselist = kw.get('uselist', False)
     if uselist:
         factory = kw.pop('typecallable', None)
         typecallable = manager.instrument_collection_class(
             key, factory or list)
     else:
         typecallable = kw.pop('typecallable', None)
-        
-    manager[key].impl = _create_prop(class_, key, manager, typecallable=typecallable, **kw)
+
+    if impl_class:
+        impl = impl_class(class_, key, typecallable, **kw)
+    elif uselist:
+        impl = CollectionAttributeImpl(class_, key, callable_,
+                                       typecallable=typecallable, **kw)
+    elif useobject:
+        impl = ScalarObjectAttributeImpl(class_, key, callable_, **kw)
+    elif mutable_scalars:
+        impl = MutableScalarAttributeImpl(class_, key, callable_,
+                                          class_manager=manager, **kw)
+    else:
+        impl = ScalarAttributeImpl(class_, key, callable_, **kw)
+
+    manager[key].impl = impl
+    
     manager.post_configure_attribute(key)
     
 def register_descriptor(class_, key, proxy_property=None, comparator=None, parententity=None, property_=None):
@@ -1712,11 +1764,11 @@ def is_instrumented(instance, key):
 class InstrumentationRegistry(object):
     """Private instrumentation registration singleton."""
 
-    manager_finders = weakref.WeakKeyDictionary()
-    state_finders = util.WeakIdentityMapping()
-    extended = False
+    _manager_finders = weakref.WeakKeyDictionary()
+    _state_finders = util.WeakIdentityMapping()
+    _extended = False
 
-    def create_manager_for_cls(self, class_):
+    def create_manager_for_cls(self, class_, **kw):
         assert class_ is not None
         assert manager_of_class(class_) is None
 
@@ -1727,9 +1779,9 @@ class InstrumentationRegistry(object):
         else:
             factory = ClassManager
 
-        existing_factories = collect_management_factories_for(class_)
-        existing_factories.add(factory)
-        if len(existing_factories) > 1:
+        existing_factories = self._collect_management_factories_for(class_).\
+                                difference([factory])
+        if existing_factories:
             raise TypeError(
                 "multiple instrumentation implementations specified "
                 "in %s inheritance hierarchy: %r" % (
@@ -1738,21 +1790,49 @@ class InstrumentationRegistry(object):
         manager = factory(class_)
         if not isinstance(manager, ClassManager):
             manager = _ClassInstrumentationAdapter(class_, manager)
-        if factory != ClassManager and not self.extended:
-            self.extended = True
+            
+        if factory != ClassManager and not self._extended:
+            self._extended = True
             _install_lookup_strategy(self)
+        
+        manager._configure_create_arguments(**kw)
 
         manager.factory = factory
-        manager.manage()
-        self.manager_finders[class_] = manager.manager_getter()
-        self.state_finders[class_] = manager.state_getter()
+        self._manager_finders[class_] = manager.manager_getter()
+        self._state_finders[class_] = manager.state_getter()
         return manager
 
+    def _collect_management_factories_for(self, cls):
+        """Return a collection of factories in play or specified for a hierarchy.
+
+        Traverses the entire inheritance graph of a cls and returns a collection
+        of instrumentation factories for those classes.  Factories are extracted
+        from active ClassManagers, if available, otherwise
+        instrumentation_finders is consulted.
+
+        """
+        hierarchy = util.class_hierarchy(cls)
+        factories = set()
+        for member in hierarchy:
+            manager = manager_of_class(member)
+            if manager is not None:
+                factories.add(manager.factory)
+            else:
+                for finder in instrumentation_finders:
+                    factory = finder(member)
+                    if factory is not None:
+                        break
+                else:
+                    factory = None
+                factories.add(factory)
+        factories.discard(None)
+        return factories
+
     def manager_of_class(self, cls):
         if cls is None:
             return None
         try:
-            finder = self.manager_finders[cls]
+            finder = self._manager_finders[cls]
         except KeyError:
             return None
         else:
@@ -1762,7 +1842,7 @@ class InstrumentationRegistry(object):
         if instance is None:
             raise AttributeError("None has no persistent state.")
         try:
-            return self.state_finders[instance.__class__](instance)
+            return self._state_finders[instance.__class__](instance)
         except KeyError:
             raise AttributeError("%r is not instrumented" % instance.__class__)
 
@@ -1770,7 +1850,7 @@ class InstrumentationRegistry(object):
         if instance is None:
             return default
         try:
-            finder = self.state_finders[instance.__class__]
+            finder = self._state_finders[instance.__class__]
         except KeyError:
             return default
         else:
@@ -1782,49 +1862,33 @@ class InstrumentationRegistry(object):
                 raise
 
     def unregister(self, class_):
-        if class_ in self.manager_finders:
+        if class_ in self._manager_finders:
             manager = self.manager_of_class(class_)
+            manager.unregister()
             manager.dispose()
-            del self.manager_finders[class_]
-            del self.state_finders[class_]
-
-# Create a registry singleton and prepare placeholders for lookup functions.
+            del self._manager_finders[class_]
+            del self._state_finders[class_]
 
 instrumentation_registry = InstrumentationRegistry()
 
-create_manager_for_cls = None
-
-manager_of_class = None
-
-instance_state = None
-
-
-_lookup_strategy = None
-
 def _install_lookup_strategy(implementation):
-    """Switch between native and extended instrumentation modes.
-
-    Completely private.  Use the instrumentation_finders interface to
-    inject global instrumentation behavior.
-
+    """Replace global class/object management functions
+    with either faster or more comprehensive implementations,
+    based on whether or not extended class instrumentation
+    has been detected.
+    
+    This function is called only by InstrumentationRegistry()
+    and unit tests specific to this behavior.
+    
     """
-    global manager_of_class, instance_state, create_manager_for_cls
-    global _lookup_strategy
-
-    # Using a symbol here to make debugging a little friendlier.
-    if implementation is not util.symbol('native'):
-        manager_of_class = implementation.manager_of_class
-        instance_state = implementation.state_of
-        create_manager_for_cls = implementation.create_manager_for_cls
-    else:
-        def manager_of_class(class_):
-            return getattr(class_, ClassManager.MANAGER_ATTR, None)
-        manager_of_class = instrumentation_registry.manager_of_class
+    global instance_state
+    if implementation is util.symbol('native'):
         instance_state = attrgetter(ClassManager.STATE_ATTR)
-        create_manager_for_cls = instrumentation_registry.create_manager_for_cls
-    # TODO: maybe log an event when setting a strategy.
-    _lookup_strategy = implementation
-
+    else:
+        instance_state = instrumentation_registry.state_of
+    
+manager_of_class = instrumentation_registry.manager_of_class
+_create_manager_for_cls = instrumentation_registry.create_manager_for_cls
 _install_lookup_strategy(util.symbol('native'))
 
 def find_native_user_instrumentation_hook(cls):
@@ -1832,54 +1896,12 @@ def find_native_user_instrumentation_hook(cls):
     return getattr(cls, INSTRUMENTATION_MANAGER, None)
 instrumentation_finders.append(find_native_user_instrumentation_hook)
 
-def collect_management_factories_for(cls):
-    """Return a collection of factories in play or specified for a hierarchy.
-
-    Traverses the entire inheritance graph of a cls and returns a collection
-    of instrumentation factories for those classes.  Factories are extracted
-    from active ClassManagers, if available, otherwise
-    instrumentation_finders is consulted.
-
-    """
-    hierarchy = util.class_hierarchy(cls)
-    factories = set()
-    for member in hierarchy:
-        manager = manager_of_class(member)
-        if manager is not None:
-            factories.add(manager.factory)
-        else:
-            for finder in instrumentation_finders:
-                factory = finder(member)
-                if factory is not None:
-                    break
-            else:
-                factory = None
-            factories.add(factory)
-    factories.discard(None)
-    return factories
-
-def _create_prop(class_, key, class_manager, 
-                    uselist=False, callable_=None, typecallable=None, 
-                    useobject=False, mutable_scalars=False, 
-                    impl_class=None, **kwargs):
-    if impl_class:
-        return impl_class(class_, key, typecallable, **kwargs)
-    elif uselist:
-        return CollectionAttributeImpl(class_, key, callable_,
-                                       typecallable=typecallable,
-                                       **kwargs)
-    elif useobject:
-        return ScalarObjectAttributeImpl(class_, key, callable_,
-                                         **kwargs)
-    elif mutable_scalars:
-        return MutableScalarAttributeImpl(class_, key, callable_,
-                                          class_manager=class_manager, **kwargs)
-    else:
-        return ScalarAttributeImpl(class_, key, callable_, **kwargs)
-
 def _generate_init(class_, class_manager):
     """Build an __init__ decorator that triggers ClassManager events."""
 
+    # TODO: we should use the ClassManager's notion of the 
+    # original '__init__' method, once ClassManager is fixed
+    # to always reference that.
     original__init__ = class_.__init__
     assert original__init__
 
@@ -1897,7 +1919,6 @@ def __init__(%(apply_pos)s):
 """
     func_vars = util.format_argspec_init(original__init__, grouped=False)
     func_text = func_body % func_vars
-    #TODO: log debug #print func_text
 
     func = getattr(original__init__, 'im_func', original__init__)
     func_defaults = getattr(func, 'func_defaults', None)
index 81d5404d77b54062f80561ccc9fc912dd3a497d7..98a027448eb9a5c2c560e616d956266ac1c8f373 100644 (file)
@@ -338,22 +338,28 @@ class Mapper(object):
             return
 
         if manager is not None:
-            if manager.class_ is not self.class_:
-                # An inherited manager.  Install one for this subclass.
-                # TODO: no coverage here
-                manager = None
-            elif manager.mapper:
+            assert manager.class_ is self.class_
+            if manager.mapper:
                 raise sa_exc.ArgumentError(
                     "Class '%s' already has a primary mapper defined. "
                     "Use non_primary=True to "
                     "create a non primary Mapper.  clear_mappers() will "
                     "remove *all* current mappers from all classes." %
                     self.class_)
-
+            #else:
+                # a ClassManager may already exist as 
+                # ClassManager.instrument_attribute() creates 
+                # new managers for each subclass if they don't yet exist.
+                
         _mapper_registry[self] = True
 
+        self.extension.instrument_class(self, self.class_)
+
         if manager is None:
-            manager = attributes.create_manager_for_cls(self.class_)
+            manager = attributes.register_class(self.class_, 
+                instance_state_factory = IdentityManagedState,
+                deferred_scalar_loader = _load_scalar_attributes
+            )
 
         self.class_manager = manager
 
@@ -363,12 +369,6 @@ class Mapper(object):
         if manager.info.get(_INSTRUMENTOR, False):
             return
 
-        self.extension.instrument_class(self, self.class_)
-
-        manager.instantiable = True
-        manager.instance_state_factory = IdentityManagedState
-        manager.deferred_scalar_loader = _load_scalar_attributes
-
         event_registry = manager.events
         event_registry.add_listener('on_init', _event_on_init)
         event_registry.add_listener('on_init_failure', _event_on_init_failure)
@@ -390,16 +390,11 @@ class Mapper(object):
     def dispose(self):
         # Disable any attribute-based compilation.
         self.compiled = True
-        manager = self.class_manager
+        
         if hasattr(self, '_compile_failed'):
             del self._compile_failed
-        if not self.non_primary and manager.mapper is self:
-            manager.mapper = None
-            manager.events.remove_listener('on_init', _event_on_init)
-            manager.events.remove_listener('on_init_failure',
-                                           _event_on_init_failure)
-            manager.uninstall_member('__init__')
-            del manager.info[_INSTRUMENTOR]
+            
+        if not self.non_primary and self.class_manager.mapper is self:
             attributes.unregister_class(self.class_)
 
     def _configure_pks(self):
index 6b2c04b71ea0208a0e018fed50dbaab45fdfcf82..69164ebafb41f638ca6b60fa4884c918b31e840e 100644 (file)
@@ -158,7 +158,7 @@ class UserDefinedExtensionTest(_base.ORMTest):
             attributes.register_attribute(Foo, 'a', uselist=False, useobject=False)
             attributes.register_attribute(Foo, 'b', uselist=False, useobject=False)
             
-            assert Foo in attributes.instrumentation_registry.state_finders
+            assert Foo in attributes.instrumentation_registry._state_finders
             f = Foo()
             attributes.instance_state(f).expire_attributes(None)
             self.assertEquals(f.a, "this is a")
index a9d186632b9b36884c59e583543f4c241ea67075..081c46cdd886540738d916e335e71fc7462f6653 100644 (file)
@@ -2,7 +2,7 @@ import testenv; testenv.configure_for_tests()
 
 from testlib import sa
 from testlib.sa import MetaData, Table, Column, Integer, ForeignKey
-from testlib.sa.orm import mapper, relation, create_session, attributes
+from testlib.sa.orm import mapper, relation, create_session, attributes, class_mapper
 from testlib.testing import eq_, ne_
 from testlib.compat import _function_named
 from orm import _base
@@ -21,12 +21,11 @@ def modifies_instrumentation_finders(fn):
 def with_lookup_strategy(strategy):
     def decorate(fn):
         def wrapped(*args, **kw):
-            current = attributes._lookup_strategy
             try:
                 attributes._install_lookup_strategy(strategy)
                 return fn(*args, **kw)
             finally:
-                attributes._install_lookup_strategy(current)
+                attributes._install_lookup_strategy(sa.util.symbol('native'))
         return _function_named(wrapped, fn.func_name)
     return decorate
 
@@ -454,10 +453,10 @@ class MapperInitTest(_base.ORMTest):
             pass
 
         class C(B):
-            def __init__(self):
+            def __init__(self, x):
                 pass
 
-        mapper(A, self.fixture())
+        m = mapper(A, self.fixture())
 
         a = attributes.instance_state(A())
         assert isinstance(a, attributes.InstanceState)
@@ -467,11 +466,19 @@ class MapperInitTest(_base.ORMTest):
         assert isinstance(b, attributes.InstanceState)
         assert type(b) is not attributes.InstanceState
 
-        # C is unmanaged
-        cobj = C()
-        self.assertRaises((AttributeError, TypeError),
-                          attributes.instance_state, cobj)
+        # B is not mapped in the current implementation
+        self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, B)
+
+        # the constructor of C is decorated too.  
+        # we don't support unmapped subclasses in any case,
+        # users should not be expecting any particular behavior
+        # from this scenario.
+        c = attributes.instance_state(C(3))
+        assert isinstance(c, attributes.InstanceState)
+        assert type(c) is not attributes.InstanceState
 
+        # C is not mapped in the current implementation
+        self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, C)
 
 class InstrumentationCollisionTest(_base.ORMTest):
     def test_none(self):
@@ -653,7 +660,16 @@ class MiscTest(_base.ORMTest):
 
         a = A()
         assert not a.bs
-
+    
+    def test_uninstrument(self):
+        class A(object):pass
+        
+        manager = attributes.register_class(A)
+        
+        assert attributes.manager_of_class(A) is manager
+        attributes.unregister_class(A)
+        assert attributes.manager_of_class(A) is None
+        
     def test_compileonattr_rel_backref_a(self):
         m = MetaData()
         t1 = Table('t1', m,