]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- most tests passing on adapted MapperExtension
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 6 Nov 2010 20:43:48 +0000 (16:43 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 6 Nov 2010 20:43:48 +0000 (16:43 -0400)
lib/sqlalchemy/event.py
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/deprecated_interfaces.py
lib/sqlalchemy/orm/events.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/state.py
test/orm/test_generative.py
test/orm/test_mapper.py

index 0f6342e6b2844171772119bcfed373a4936a5e24..379c3f1dd022d065982e1b58adadb0512d27284e 100644 (file)
@@ -68,11 +68,12 @@ class _Dispatch(object):
             object."""
 
         for ls in other.descriptors:
-            existing_listeners = getattr(self, ls.name).listeners
-            existing_listener_set = set(existing_listeners)
-            existing_listeners.extend([l for l 
-                                    in ls.listeners 
-                                    if l not in existing_listener_set])
+            getattr(self, ls.name).update(ls)
+            #existing_listeners = getattr(self, ls.name).listeners
+            #existing_listener_set = set(existing_listeners)
+            #existing_listeners.extend([l for l 
+            #                        in ls.listeners 
+            #                        if l not in existing_listener_set])
 
 class _EventMeta(type):
     """Intercept new Event subclasses and create 
@@ -198,7 +199,17 @@ class _ListenerCollection(object):
         
     def __nonzero__(self):
         return bool(self.listeners or self.parent_listeners)
-        
+    
+    def update(self, other):
+        """Populate from the listeners in another :class:`_Dispatch`
+            object."""
+
+        existing_listeners = self.listeners
+        existing_listener_set = set(existing_listeners)
+        existing_listeners.extend([l for l 
+                                in other.listeners 
+                                if l not in existing_listener_set])
+
     def append(self, obj, target):
         if obj not in self.listeners:
             self.listeners.append(obj)
index 1c5e630cf47cdf93fed820e353b2d451e9dc9e36..6df13e94f02d1b2f6204eaee545cae92c677ea9a 100644 (file)
@@ -82,7 +82,6 @@ __all__ = (
     'dynamic_loader',
     'eagerload',
     'eagerload_all',
-    'extension',
     'immediateload',
     'join',
     'joinedload',
index 2145bef4b9d5d1b5342505680dd2888b7a6c2555..3817dc2eee3a6687130c6b357c76a4f3cca54e2c 100644 (file)
@@ -1,4 +1,4 @@
-from sqlalchemy import event
+from sqlalchemy import event, util
 from interfaces import EXT_CONTINUE
 
 
@@ -56,20 +56,44 @@ class MapperExtension(object):
         
     @classmethod
     def _adapt_listener_methods(cls, self, listener, methods):
+        
         for meth in methods:
             me_meth = getattr(MapperExtension, meth)
             ls_meth = getattr(listener, meth)
+            
             # TODO: comparing self.methods to cls.method, 
             # this comparison is probably moot
+            
             if me_meth is not ls_meth:
                 if meth == 'reconstruct_instance':
                     def go(ls_meth):
                         def reconstruct(instance):
                             ls_meth(self, instance)
                         return reconstruct
-                    event.listen(go(ls_meth), 'on_load', self.class_manager, raw=False)
+                    event.listen(go(ls_meth), 'on_load', 
+                                        self.class_manager, raw=False)
+                elif meth == 'init_instance':
+                    def go(ls_meth):
+                        def init_instance(instance, args, kwargs):
+                            ls_meth(self, self.class_, 
+                                        self.class_manager.original_init, 
+                                        instance, args, kwargs)
+                        return init_instance
+                    event.listen(go(ls_meth), 'on_init', 
+                                            self.class_manager, raw=False)
+                elif meth == 'init_failed':
+                    def go(ls_meth):
+                        def init_failed(instance, args, kwargs):
+                            util.warn_exception(ls_meth, self, self.class_, 
+                                            self.class_manager.original_init, 
+                                            instance, args, kwargs)
+                            
+                        return init_failed
+                    event.listen(go(ls_meth), 'on_init_failure', 
+                                        self.class_manager, raw=False)
                 else:
-                    event.listen(ls_meth, "on_%s" % meth, self, raw=False, retval=True)
+                    event.listen(ls_meth, "on_%s" % meth, self, 
+                                        raw=False, retval=True)
 
 
     def instrument_class(self, mapper, class_):
index c3eab67e146ddf960ff0c17b46dc03b54035fd0c..5765b9ea8bafcc69c2e8ea24dc9c126c6d5a12b8 100644 (file)
@@ -83,13 +83,39 @@ class InstanceEvents(event.Events):
         raise NotImplementedError("Removal of instance events not yet implemented")
         
     def on_init(self, target, args, kwargs):
-        """"""
+        """Receive an instance when it's constructor is called.
+        
+        This method is only called during a userland construction of 
+        an object.  It is not called when an object is loaded from the
+        database.
+
+        """
         
     def on_init_failure(self, target, args, kwargs):
-        """"""
+        """Receive an instance when it's constructor has been called, 
+        and raised an exception.
+        
+        This method is only called during a userland construction of 
+        an object.  It is not called when an object is loaded from the
+        database.
+
+        """
     
     def on_load(self, target):
-        """"""
+        """Receive an object instance after it has been created via
+        ``__new__``, and after initial attribute population has
+        occurred.
+
+        This typically occurs when the instance is created based on
+        incoming result rows, and is only called once for that
+        instance's lifetime.
+
+        Note that during a result-row load, this method is called upon
+        the first row received for this instance.  Note that some 
+        attributes and collections may or may not be loaded or even 
+        initialized, depending on what's present in the result rows.
+
+        """
     
     def on_resurrect(self, target):
         """"""
@@ -180,31 +206,6 @@ class MapperEvents(event.Events):
         
         """
 
-    def on_init_instance(self, mapper, class_, oldinit, target, args, kwargs):
-        """Receive an instance when it's constructor is called.
-        
-        This method is only called during a userland construction of 
-        an object.  It is not called when an object is loaded from the
-        database.
-        
-        The return value is only significant within the ``MapperExtension`` 
-        chain; the parent mapper's behavior isn't modified by this method.
-
-        """
-
-    def on_init_failed(self, mapper, class_, oldinit, target, args, kwargs):
-        """Receive an instance when it's constructor has been called, 
-        and raised an exception.
-        
-        This method is only called during a userland construction of 
-        an object.  It is not called when an object is loaded from the
-        database.
-        
-        The return value is only significant within the ``MapperExtension`` 
-        chain; the parent mapper's behavior isn't modified by this method.
-
-        """
-
     def on_translate_row(self, mapper, context, row):
         """Perform pre-processing on the given result row and return a
         new row instance.
@@ -306,25 +307,6 @@ class MapperEvents(event.Events):
 
         """
 
-    def on_reconstruct_instance(self, mapper, target):
-        """Receive an object instance after it has been created via
-        ``__new__``, and after initial attribute population has
-        occurred.
-
-        This typically occurs when the instance is created based on
-        incoming result rows, and is only called once for that
-        instance's lifetime.
-
-        Note that during a result-row load, this method is called upon
-        the first row received for this instance.  Note that some 
-        attributes and collections may or may not be loaded or even 
-        initialized, depending on what's present in the result rows.
-
-        The return value is only significant within the ``MapperExtension`` 
-        chain; the parent mapper's behavior isn't modified by this method.
-
-        """
-
     def on_before_insert(self, mapper, connection, target):
         """Receive an object instance before that instance is inserted
         into its table.
@@ -341,7 +323,6 @@ class MapperEvents(event.Events):
 
         """
 
-
     def on_after_insert(self, mapper, connection, target):
         """Receive an object instance after that instance is inserted.
         
index 643ce6faa54e7fb4e63c58328ddb515aeb100fb1..03e313685e04c416dfa268883bb6333a8fefe0d9 100644 (file)
@@ -141,8 +141,8 @@ class Mapper(object):
         self._inherits_equated_pairs = None
         self._memoized_values = {}
         self._compiled_cache_size = _compiled_cache_size
-
-        self._deprecated_extensions = extension
+        self._reconstructor = None
+        self._deprecated_extensions = util.to_list(extension or [])
         
         if allow_null_pks:
             util.warn_deprecated(
@@ -322,19 +322,32 @@ class Mapper(object):
                     % self)
     
     def _configure_legacy_instrument_class(self):
-        # TODO: tests failing
-        for ext in util.to_list(self._deprecated_extensions or []):
-            ext._adapt_instrument_class(self, ext)
+
+        if self.inherits:
+            self.dispatch.update(self.inherits.dispatch)
+            super_extensions = set(chain(*[m._deprecated_extensions 
+                                    for m in self.inherits.iterate_to_root()]))
+        else:
+            super_extensions = set()
+            
+        for ext in self._deprecated_extensions:
+            if ext not in super_extensions:
+                ext._adapt_instrument_class(self, ext)
 
     def _configure_listeners(self):
-        # TODO: this has to be made smarter to look
-        # for existing extensions
-        
-        for ext in util.to_list(self._deprecated_extensions or []):
-            ext._adapt_listener(self, ext)
+        if self.inherits:
+            super_extensions = set(chain(*[m._deprecated_extensions 
+                                    for m in self.inherits.iterate_to_root()]))
+        else:
+            super_extensions = set()
+
+        for ext in self._deprecated_extensions:
+            if ext not in super_extensions:
+                ext._adapt_listener(self, ext)
         
         if self.inherits:
-            self.dispatch.update(self.inherits.dispatch)
+            self.class_manager.dispatch.update(
+                        self.inherits.class_manager.dispatch)
 
     def _configure_class_instrumentation(self):
         """If this mapper is to be a primary mapper (i.e. the
@@ -398,7 +411,8 @@ class Mapper(object):
         for key, method in util.iterate_attributes(self.class_):
             if isinstance(method, types.FunctionType):
                 if hasattr(method, '__sa_reconstructor__'):
-                    event.listen(method, 'on_load', manager, raw=True)
+                    self._reconstructor = method
+                    event.listen(_event_on_load, 'on_load', manager, raw=True)
                 elif hasattr(method, '__sa_validators__'):
                     for name in method.__sa_validators__:
                         self._validators[name] = method
@@ -2267,17 +2281,22 @@ class Mapper(object):
                         attrs = state.unloaded
                         # allow query.instances to commit the subset of attrs
                         context.partials[state] = (dict_, attrs)  
-
-                    if not populate_instance or \
-                            populate_instance(self, context, row, instance, 
+                    
+                    if populate_instance:
+                        for fn in populate_instance:
+                            ret = fn(self, context, row, state, 
                                 only_load_props=attrs, 
-                                instancekey=identitykey, isnew=isnew) is \
-                                EXT_CONTINUE:
+                                instancekey=identitykey, isnew=isnew)
+                            if ret is not EXT_CONTINUE:
+                                break
+                        else:
+                            populate_state(state, dict_, row, isnew, attrs)
+                    else:
                         populate_state(state, dict_, row, isnew, attrs)
 
             if loaded_instance:
-                state._run_on_load()
-
+                state.manager.dispatch.on_load(state)
+                
             if result is not None:
                 if append_result:
                     for fn in append_result:
@@ -2382,25 +2401,22 @@ def validates(*names):
         return fn
     return wrap
 
+def _event_on_load(state):
+    instrumenting_mapper = state.manager.info[_INSTRUMENTOR]
+    if instrumenting_mapper._reconstructor:
+        instrumenting_mapper._reconstructor(state.obj())
+        
 def _event_on_init(state, args, kwargs):
     """Trigger mapper compilation and run init_instance hooks."""
 
     instrumenting_mapper = state.manager.info[_INSTRUMENTOR]
     # compile() always compiles all mappers
     instrumenting_mapper.compile()
-    instrumenting_mapper.dispatch.on_init_instance(
-        instrumenting_mapper, instrumenting_mapper.class_,
-        state.manager.original_init,
-        state, args, kwargs)
 
 def _event_on_init_failure(state, args, kwargs):
     """Run init_failed hooks."""
 
     instrumenting_mapper = state.manager.info[_INSTRUMENTOR]
-    util.warn_exception(
-        instrumenting_mapper.dispatch.on_init_failed,
-        instrumenting_mapper, instrumenting_mapper.class_,
-        state.manager.original_init, state, args, kwargs)
 
 def _event_on_resurrect(state):
     # re-populate the primary key elements
index 5e1c7ba094a07b1c5bfa504d28ef721798ff3988..710d3213a4b58226ee81ecc8fb6f042186ccbd76 100644 (file)
@@ -1241,7 +1241,7 @@ class Session(object):
             merged_state.commit_all(merged_dict, self.identity_map)  
 
         if new_instance:
-            merged_state._run_on_load()
+            merged_state.manager.dispatch.on_load(merged_state)
         return merged
 
     @classmethod
index bea4ee500a11a402f92eeff56e1fd31fb820578a..3e977a4c9cf360d59acc682fa18f8c85f3c41660 100644 (file)
@@ -142,9 +142,6 @@ class InstanceState(object):
         else:
             return [x]
 
-    def _run_on_load(self):
-        self.manager.dispatch.on_load(self)
-    
     def __getstate__(self):
         d = {'instance':self.obj()}
 
index 141fde9fc65c263d288f9932b1de321d09aaaffe..06c07dc62977981dc321f0910427b04f961821b9 100644 (file)
@@ -122,15 +122,6 @@ class GenerativeQueryTest(_base.MappedTest):
         res2 = query.filter(Foo.bar < 30).filter(Foo.bar > 10)
         assert res2.count() == 19
 
-    @testing.resolve_artifact_names
-    def test_options(self):
-        query = create_session().query(Foo)
-        class ext1(sa.orm.MapperExtension):
-            def populate_instance(self, mapper, selectcontext, row, instance, **flags):
-                instance.TEST = "hello world"
-                return sa.orm.EXT_CONTINUE
-        assert query.options(sa.orm.extension(ext1()))[0].TEST == "hello world"
-
     @testing.resolve_artifact_names
     def test_order_by(self):
         query = create_session().query(Foo)
index 05cf1fd31de175d909ae1dcfb1b710cd0298be02..b6432a39aa462a8b5c7c1dca03e8f3e51882e182 100644 (file)
@@ -194,6 +194,7 @@ class MapperTest(_fixtures.FixtureTest):
     @testing.resolve_artifact_names
     def test_constructor_exc_1(self):
         """Exceptions raised in the mapped class are not masked by sa decorations"""
+        
         ex = AssertionError('oops')
         sess = create_session()
 
@@ -268,27 +269,6 @@ class MapperTest(_fixtures.FixtureTest):
         mapper(Foo, addresses, inherits=User)
         assert getattr(Foo().__class__, 'name').impl is not None
 
-    @testing.resolve_artifact_names
-    def test_extension_collection_frozen(self):
-        class Foo(User):pass
-        m = mapper(User, users)
-        mapper(Order, orders)
-        compile_mappers()
-        mapper(Foo, addresses, inherits=User)
-        ext_list = [AttributeExtension()]
-        m.add_property('somename', column_property(users.c.name, extension=ext_list))
-        m.add_property('orders', relationship(Order, extension=ext_list, backref='user'))
-        assert len(ext_list) == 1
-
-        assert Foo.orders.impl.extensions is User.orders.impl.extensions
-        assert Foo.orders.impl.extensions is not ext_list
-        
-        compile_mappers()
-        assert len(User.somename.impl.extensions) == 1
-        assert len(Foo.somename.impl.extensions) == 1
-        assert len(Foo.orders.impl.extensions) == 3
-        assert len(User.orders.impl.extensions) == 3
-        
 
     @testing.resolve_artifact_names
     def test_compile_on_get_props_1(self):
@@ -1073,16 +1053,19 @@ class MapperTest(_fixtures.FixtureTest):
         class A(object):
             @reconstructor
             def reconstruct(self):
+                assert isinstance(self, A)
                 recon.append('A')
 
         class B(A):
             @reconstructor
             def reconstruct(self):
+                assert isinstance(self, B)
                 recon.append('B')
 
         class C(A):
             @reconstructor
             def reconstruct(self):
+                assert isinstance(self, C)
                 recon.append('C')
 
         mapper(A, users, polymorphic_on=users.c.name,