]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- inheriting mappers now inherit the MapperExtensions of their parent
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 20 Dec 2007 02:37:48 +0000 (02:37 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 20 Dec 2007 02:37:48 +0000 (02:37 +0000)
mapper directly, so that all methods for a particular MapperExtension
are called for subclasses as well.  As always, any MapperExtension
can return either EXT_CONTINUE to continue extension processing
or EXT_STOP to stop processing.  The order of mapper resolution is:
<extensions declared on the classes mapper> <extensions declared on the
classes' parent mapper> <globally declared extensions>.

Note that if you instantiate the same extension class separately
and then apply it individually for two mappers in the same inheritance
chain, the extension will be applied twice to the inheriting class,
and each method will be called twice.

To apply a mapper extension explicitly to each inheriting class but
have each method called only once per operation, use the same
instance of the extension for both mappers.
[ticket:490]

CHANGES
lib/sqlalchemy/orm/mapper.py
test/orm/mapper.py

diff --git a/CHANGES b/CHANGES
index 85acf228d3d8837d10a65cbdc4679661d697dfac..d8075c921d7a674519bf1e290de59cb425c96aa7 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -82,6 +82,24 @@ CHANGES
      issued directly by the ORM in the form of UPDATE statements, by setting
      the flag "passive_cascades=False".
 
+   - inheriting mappers now inherit the MapperExtensions of their parent
+     mapper directly, so that all methods for a particular MapperExtension
+     are called for subclasses as well.  As always, any MapperExtension 
+     can return either EXT_CONTINUE to continue extension processing
+     or EXT_STOP to stop processing.  The order of mapper resolution is:
+     <extensions declared on the classes mapper> <extensions declared on the
+     classes' parent mapper> <globally declared extensions>.
+     
+     Note that if you instantiate the same extension class separately 
+     and then apply it individually for two mappers in the same inheritance 
+     chain, the extension will be applied twice to the inheriting class,
+     and each method will be called twice.
+     
+     To apply a mapper extension explicitly to each inheriting class but
+     have each method called only once per operation, use the same 
+     instance of the extension for both mappers.
+     [ticket:490]
+     
    - new synonym() behavior: an attribute will be placed on the mapped
      class, if one does not exist already, in all cases. if a property
      already exists on the class, the synonym will decorate the property
index 2c3d3ea40614d305bb1f68a43b0ca79c1c32a6a1..8c375ea392c88f806d26604bd6d1721b7fbd8710 100644 (file)
@@ -153,8 +153,8 @@ class Mapper(object):
         self.__should_log_debug = logging.is_debug_enabled(self.logger)
         
         self._compile_class()
-        self._compile_extensions()
         self._compile_inheritance()
+        self._compile_extensions()
         self._compile_tables()
         self._compile_properties()
         self._compile_pks()
@@ -281,12 +281,19 @@ class Mapper(object):
             for ext_obj in util.to_list(extension):
                 # local MapperExtensions have already instrumented the class
                 extlist.add(ext_obj)
-
-        for ext in global_extensions:
-            if isinstance(ext, type):
-                ext = ext()
-            extlist.add(ext)
-            ext.instrument_class(self, self.class_)
+        
+        if self.inherits is not None:
+            for ext in self.inherits.extension:
+                if ext not in extlist:
+                    extlist.add(ext)
+                    ext.instrument_class(self, self.class_)
+        else:
+            for ext in global_extensions:
+                if isinstance(ext, type):
+                    ext = ext()
+                if ext not in extlist:
+                    extlist.add(ext)
+                    ext.instrument_class(self, self.class_)
             
         self.extension = ExtensionCarrier()
         for ext in extlist:
@@ -960,14 +967,13 @@ class Mapper(object):
         if not postupdate:
             # call before_XXX extensions
             for state, connection, has_identity in tups:
+                mapper = _state_mapper(state)
                 if not has_identity:
-                    for mapper in _state_mapper(state).iterate_to_root():
-                        if 'before_insert' in mapper.extension.methods:
-                            mapper.extension.before_insert(mapper, connection, state.obj())
+                    if 'before_insert' in mapper.extension.methods:
+                        mapper.extension.before_insert(mapper, connection, state.obj())
                 else:
-                    for mapper in _state_mapper(state).iterate_to_root():
-                        if 'before_update' in mapper.extension.methods:
-                            mapper.extension.before_update(mapper, connection, state.obj())
+                    if 'before_update' in mapper.extension.methods:
+                        mapper.extension.before_update(mapper, connection, state.obj())
 
         for state, connection, has_identity in tups:
             # detect if we have a "pending" instance (i.e. has no instance_key attached to it),
@@ -1131,13 +1137,14 @@ class Mapper(object):
         if not postupdate:
             # call after_XXX extensions
             for state, connection in inserted_objects:
-                for mapper in _state_mapper(state).iterate_to_root():
-                    if 'after_insert' in mapper.extension.methods:
-                        mapper.extension.after_insert(mapper, connection, state.obj())
+                mapper = _state_mapper(state)
+                if 'after_insert' in mapper.extension.methods:
+                    mapper.extension.after_insert(mapper, connection, state.obj())
+
             for state, connection in updated_objects:
-                for mapper in _state_mapper(state).iterate_to_root():
-                    if 'after_update' in mapper.extension.methods:
-                        mapper.extension.after_update(mapper, connection, state.obj())
+                mapper = _state_mapper(state)
+                if 'after_update' in mapper.extension.methods:
+                    mapper.extension.after_update(mapper, connection, state.obj())
     
     def _postfetch(self, connection, table, state, resultproxy, params, value_params):
         """After an ``INSERT`` or ``UPDATE``, assemble newly generated
@@ -1177,9 +1184,9 @@ class Mapper(object):
             tups = [(state, connection) for state in states]
 
         for (state, connection) in tups:
-            for mapper in _state_mapper(state).iterate_to_root():
-                if 'before_delete' in mapper.extension.methods:
-                    mapper.extension.before_delete(mapper, connection, state.obj())
+            mapper = _state_mapper(state)
+            if 'before_delete' in mapper.extension.methods:
+                mapper.extension.before_delete(mapper, connection, state.obj())
 
         deleted_objects = util.Set()
         table_to_mapper = {}
@@ -1225,9 +1232,9 @@ class Mapper(object):
                     raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects)))
 
         for state, connection in deleted_objects:
-            for mapper in _state_mapper(state).iterate_to_root():
-                if 'after_delete' in mapper.extension.methods:
-                    mapper.extension.after_delete(mapper, connection, state.obj())
+            mapper = _state_mapper(state)
+            if 'after_delete' in mapper.extension.methods:
+                mapper.extension.after_delete(mapper, connection, state.obj())
 
     def _register_dependencies(self, uowcommit):
         """Register ``DependencyProcessor`` instances with a
index a1ca19a64f16bd0c9ded715399fe18e98dac2ccd..662ac4a29f249c5aa7e68a3ab17471a810b9d772 100644 (file)
@@ -1131,83 +1131,73 @@ class NoLoadTest(MapperSuperTest):
             {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])},
             )
 
-class MapperExtensionTest(MapperSuperTest):
+class MapperExtensionTest(PersistTest):
     def setUpAll(self):
         tables.create()
-    def tearDownAll(self):
-        tables.drop()
-    def tearDown(self):
-        clear_mappers()
-        tables.delete()
-    def setUp(self):
-        tables.data()
-
-    def test_create_instance(self):
-        class Ext(MapperExtension):
-            def create_instance(self, *args, **kwargs):
-                return User()
-        m = mapper(Address, addresses)
-        m = mapper(User, users, extension=Ext(), properties = dict(
-            addresses = relation(Address, lazy=True),
-        ))
-
-        q = create_session().query(m)
-        l = q.select();
-        self.assert_result(l, User, *user_address_result)
-
-    def test_methods(self):
-        """test that common user-defined methods get called."""
-
-        methods = set()
+        
+        global methods, Ext
+        
+        methods = []
+        
         class Ext(MapperExtension):
             def load(self, query, *args, **kwargs):
-                methods.add('load')
+                methods.append('load')
                 return EXT_CONTINUE
 
             def get(self, query, *args, **kwargs):
-                methods.add('get')
+                methods.append('get')
                 return EXT_CONTINUE
 
             def translate_row(self, mapper, context, row):
-                methods.add('translate_row')
+                methods.append('translate_row')
                 return EXT_CONTINUE
 
             def create_instance(self, mapper, selectcontext, row, class_):
-                methods.add('create_instance')
+                methods.append('create_instance')
                 return EXT_CONTINUE
 
             def append_result(self, mapper, selectcontext, row, instance, result, **flags):
-                methods.add('append_result')
+                methods.append('append_result')
                 return EXT_CONTINUE
 
             def populate_instance(self, mapper, selectcontext, row, instance, **flags):
-                methods.add('populate_instance')
+                methods.append('populate_instance')
                 return EXT_CONTINUE
 
             def before_insert(self, mapper, connection, instance):
-                methods.add('before_insert')
+                methods.append('before_insert')
                 return EXT_CONTINUE
 
             def after_insert(self, mapper, connection, instance):
-                methods.add('after_insert')
+                methods.append('after_insert')
                 return EXT_CONTINUE
 
             def before_update(self, mapper, connection, instance):
-                methods.add('before_update')
+                methods.append('before_update')
                 return EXT_CONTINUE
 
             def after_update(self, mapper, connection, instance):
-                methods.add('after_update')
+                methods.append('after_update')
                 return EXT_CONTINUE
 
             def before_delete(self, mapper, connection, instance):
-                methods.add('before_delete')
+                methods.append('before_delete')
                 return EXT_CONTINUE
 
             def after_delete(self, mapper, connection, instance):
-                methods.add('after_delete')
+                methods.append('after_delete')
                 return EXT_CONTINUE
 
+    def tearDown(self):
+        clear_mappers()
+        methods[:] = []
+        tables.delete()
+    
+    def tearDownAll(self):
+        tables.drop()
+            
+    def test_basic(self):
+        """test that common user-defined methods get called."""
         mapper(User, users, extension=Ext())
         sess = create_session()
         u = User()
@@ -1220,10 +1210,54 @@ class MapperExtensionTest(MapperSuperTest):
         sess.flush()
         sess.delete(u)
         sess.flush()
-        self.assertEquals(methods, set(['load', 'before_delete', 'create_instance', 'translate_row', 'get',
-                'after_delete', 'after_insert', 'before_update', 'before_insert', 'after_update', 'populate_instance']))
+        self.assertEquals(methods, ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'get', 
+            'translate_row', 'create_instance', 'populate_instance', 'before_update', 'after_update', 'before_delete', 'after_delete'])
 
+    def test_inheritance(self):
+        # test using inheritance
+        class AdminUser(User):
+            pass
+            
+        mapper(User, users, extension=Ext())
+        mapper(AdminUser, addresses, inherits=User)
+        
+        sess = create_session()
+        am = AdminUser()
+        sess.save(am)
+        sess.flush()
+        am = sess.query(AdminUser).load(am.user_id)
+        sess.clear()
+        am = sess.query(AdminUser).get(am.user_id)
+        am.user_name = 'foobar'
+        sess.flush()
+        sess.delete(am)
+        sess.flush()
+        self.assertEquals(methods, ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'get', 
+            'translate_row', 'create_instance', 'populate_instance', 'before_update', 'after_update', 'before_delete', 'after_delete'])
+
+    def test_inheritance_with_dupes(self):
+        # test using inheritance, same extension on both mappers
+        class AdminUser(User):
+            pass
 
+        ext = Ext()
+        mapper(User, users, extension=ext)
+        mapper(AdminUser, addresses, inherits=User, extension=ext)
+
+        sess = create_session()
+        am = AdminUser()
+        sess.save(am)
+        sess.flush()
+        am = sess.query(AdminUser).load(am.user_id)
+        sess.clear()
+        am = sess.query(AdminUser).get(am.user_id)
+        am.user_name = 'foobar'
+        sess.flush()
+        sess.delete(am)
+        sess.flush()
+        self.assertEquals(methods, ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'get', 
+            'translate_row', 'create_instance', 'populate_instance', 'before_update', 'after_update', 'before_delete', 'after_delete'])
+        
 class RequirementsTest(AssertMixin):
     """Tests the contract for user classes."""