]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- adapted MapperExtensionTest into MapperEventsTest
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 9 Nov 2010 21:06:34 +0000 (16:06 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 9 Nov 2010 21:06:34 +0000 (16:06 -0500)
lib/sqlalchemy/orm/events.py
test/orm/test_mapper.py

index 205ec32355af33d92f781d965cadf170c2c45f23..4007ffc6620c3833d74858a507155dfcb934e801 100644 (file)
@@ -69,12 +69,15 @@ class InstanceEvents(event.Events):
             return target
         elif isinstance(target, Mapper):
             return target.class_manager
-        elif target is Mapper or target is mapper:
+        elif target is mapper:
             return ClassManager
         elif isinstance(target, type):
-            manager = manager_of_class(target)
-            if manager:
-                return manager
+            if issubclass(target, Mapper):
+                return ClassManager
+            else:
+                manager = manager_of_class(target)
+                if manager:
+                    return manager
         return None
     
     @classmethod
@@ -178,7 +181,10 @@ class MapperEvents(event.Events):
         if target is mapper:
             return Mapper
         elif isinstance(target, type):
-            return class_mapper(target)
+            if issubclass(target, Mapper):
+                return target
+            else:
+                return class_mapper(target)
         else:
             return target
         
@@ -211,7 +217,7 @@ class MapperEvents(event.Events):
             for mapper in target.self_and_descendants:
                 event.Events.listen(fn, identifier, mapper, propagate=True)
         else:
-            event.Events.listen(fn, identifier, self)
+            event.Events.listen(fn, identifier, target)
         
     def on_instrument_class(self, mapper, class_):
         """Receive a class when the mapper is first constructed, and has
@@ -470,10 +476,6 @@ class AttributeEvents(event.Events):
     def remove(cls, fn, identifier, target):
         raise NotImplementedError("Removal of attribute events not yet implemented")
         
-    @classmethod
-    def unwrap(cls, identifier, event):
-        return event['value']
-        
     def on_append(self, target, value, initiator):
         """Receive a collection append event.
 
index 9cc25c872864f55569ed38602f9a4a8e2dfd5bc5..9c73a7225c03538b2978a493eea4b1c32cf50442 100644 (file)
@@ -2555,6 +2555,8 @@ class AttributeExtensionTest(_base.MappedTest):
         eq_(ext_msg, ["Ex1 'a1'", "Ex1 'b1'", "Ex2 'c1'", "Ex1 'a2'", "Ex1 'b2'", "Ex2 'c2'"])
         
 class MapperEventsTest(_fixtures.FixtureTest):
+    run_inserts = None
+    
     @testing.resolve_artifact_names
     def test_instance_event_listen(self):
         """test listen targets for instance events"""
@@ -2603,9 +2605,167 @@ class MapperEventsTest(_fixtures.FixtureTest):
         # going
         Mapper.dispatch.clear()
         ClassManager.dispatch.clear()
+        super(MapperEventsTest, self).teardown()
         
+    def listen_all(self, mapper, **kw):
+        canary = []
+        def evt(meth):
+            def go(*args, **kwargs):
+                canary.append(meth)
+            return go
+            
+        for meth in [
+            'on_init',
+            'on_init_failure',
+            'on_translate_row',
+            'on_create_instance',
+            'on_append_result',
+            'on_populate_instance',
+            'on_load',
+            'on_before_insert',
+            'on_after_insert',
+            'on_before_update',
+            'on_after_update',
+            'on_before_delete',
+            'on_after_delete'
+        ]:
+            event.listen(evt(meth), meth, mapper, **kw)
+        return canary
+
+    @testing.resolve_artifact_names
+    def test_basic(self):
+
+        mapper(User, users)
+        canary = self.listen_all(User)
+        
+        sess = create_session()
+        u = User(name='u1')
+        sess.add(u)
+        sess.flush()
+        u = sess.query(User).populate_existing().get(u.id)
+        sess.expunge_all()
+        u = sess.query(User).get(u.id)
+        u.name = 'u1 changed'
+        sess.flush()
+        sess.delete(u)
+        sess.flush()
+        eq_(canary,
+            ['on_init', 'on_before_insert',
+             'on_after_insert', 'on_translate_row', 'on_populate_instance',
+             'on_append_result', 'on_translate_row', 'on_create_instance',
+             'on_populate_instance', 'on_load', 'on_append_result',
+             'on_before_update', 'on_after_update', 'on_before_delete', 'on_after_delete'])
+
+    @testing.resolve_artifact_names
+    def test_inheritance(self):
+        class AdminUser(User):
+            pass
+
+        mapper(User, users)
+        mapper(AdminUser, addresses, inherits=User)
+
+        canary1 = self.listen_all(User, propagate=True)
+        canary2 = self.listen_all(User)
+        canary3 = self.listen_all(AdminUser)
+
+        sess = create_session()
+        am = AdminUser(name='au1', email_address='au1@e1')
+        sess.add(am)
+        sess.flush()
+        am = sess.query(AdminUser).populate_existing().get(am.id)
+        sess.expunge_all()
+        am = sess.query(AdminUser).get(am.id)
+        am.name = 'au1 changed'
+        sess.flush()
+        sess.delete(am)
+        sess.flush()
+        eq_(canary1, ['on_init', 'on_before_insert', 'on_after_insert',
+            'on_translate_row', 'on_populate_instance',
+            'on_append_result', 'on_translate_row', 'on_create_instance'
+            , 'on_populate_instance', 'on_load', 'on_append_result',
+            'on_before_update', 'on_after_update', 'on_before_delete',
+            'on_after_delete'])
+        eq_(canary2, [])
+        eq_(canary3, ['on_init', 'on_before_insert', 'on_after_insert',
+            'on_translate_row', 'on_populate_instance',
+            'on_append_result', 'on_translate_row', 'on_create_instance'
+            , 'on_populate_instance', 'on_load', 'on_append_result',
+            'on_before_update', 'on_after_update', 'on_before_delete',
+            'on_after_delete'])
+
+    @testing.resolve_artifact_names
+    def test_before_after_only_collection(self):
+        """on_before_update is called on parent for collection modifications,
+        on_after_update is called even if no columns were updated.
+        
+        """
+
+        mapper(Item, items, properties={
+            'keywords': relationship(Keyword, secondary=item_keywords)})
+        mapper(Keyword, keywords)
+        
+        canary1 = self.listen_all(Item)
+        canary2 = self.listen_all(Keyword)
+        
+        sess = create_session()
+        i1 = Item(description="i1")
+        k1 = Keyword(name="k1")
+        sess.add(i1)
+        sess.add(k1)
+        sess.flush()
+        eq_(canary1,
+            ['on_init', 
+            'on_before_insert', 'on_after_insert'])
+        eq_(canary2,
+            ['on_init', 
+            'on_before_insert', 'on_after_insert'])
+
+        canary1[:]= []
+        canary2[:]= []
+
+        i1.keywords.append(k1)
+        sess.flush()
+        eq_(canary1, ['on_before_update', 'on_after_update'])
+        eq_(canary2, [])
+
+        
+    @testing.resolve_artifact_names
+    def test_retval(self):
+        def create_instance(mapper, context, row, class_):
+            u = User.__new__(User)
+            u.foo = True
+            return u
+            
+        mapper(User, users)
+        event.listen(create_instance, 'on_create_instance', 
+                        User, retval=True)
+        sess = create_session()
+        u1 = User()
+        u1.name = 'ed'
+        sess.add(u1)
+        sess.flush()
+        sess.expunge_all()
+        u = sess.query(User).first()
+        assert u.foo
+    
+    @testing.resolve_artifact_names
+    def test_instrument_event(self):
+        canary = []
+        def on_instrument_class(mapper, cls):
+            canary.append(cls)
+            
+        event.listen(on_instrument_class, 'on_instrument_class', Mapper)
+        
+        mapper(User, users)
+        eq_(canary, [User])
+        mapper(Address, addresses)
+        eq_(canary, [User, Address])
+    
         
 class MapperExtensionTest(_fixtures.FixtureTest):
+    """Superceded by MapperEventsTest - test backwards 
+    compatiblity of MapperExtension."""
+    
     run_inserts = None
     
     def extension(self):