From: Mike Bayer Date: Thu, 20 Dec 2007 02:37:48 +0000 (+0000) Subject: - inheriting mappers now inherit the MapperExtensions of their parent X-Git-Tag: rel_0_4_2~21 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=2db36bf59c447d4d113cba0ae12f1b739c2ae923;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - inheriting mappers now inherit the MapperExtensions of their parent mapper directly, so that all methods for a particular MapperExtension are called for subclasses as well. As always, any MapperExtension can return either EXT_CONTINUE to continue extension processing or EXT_STOP to stop processing. The order of mapper resolution is: . Note that if you instantiate the same extension class separately and then apply it individually for two mappers in the same inheritance chain, the extension will be applied twice to the inheriting class, and each method will be called twice. To apply a mapper extension explicitly to each inheriting class but have each method called only once per operation, use the same instance of the extension for both mappers. [ticket:490] --- diff --git a/CHANGES b/CHANGES index 85acf228d3..d8075c921d 100644 --- a/CHANGES +++ b/CHANGES @@ -82,6 +82,24 @@ CHANGES issued directly by the ORM in the form of UPDATE statements, by setting the flag "passive_cascades=False". + - inheriting mappers now inherit the MapperExtensions of their parent + mapper directly, so that all methods for a particular MapperExtension + are called for subclasses as well. As always, any MapperExtension + can return either EXT_CONTINUE to continue extension processing + or EXT_STOP to stop processing. The order of mapper resolution is: + . + + Note that if you instantiate the same extension class separately + and then apply it individually for two mappers in the same inheritance + chain, the extension will be applied twice to the inheriting class, + and each method will be called twice. + + To apply a mapper extension explicitly to each inheriting class but + have each method called only once per operation, use the same + instance of the extension for both mappers. + [ticket:490] + - new synonym() behavior: an attribute will be placed on the mapped class, if one does not exist already, in all cases. if a property already exists on the class, the synonym will decorate the property diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 2c3d3ea406..8c375ea392 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -153,8 +153,8 @@ class Mapper(object): self.__should_log_debug = logging.is_debug_enabled(self.logger) self._compile_class() - self._compile_extensions() self._compile_inheritance() + self._compile_extensions() self._compile_tables() self._compile_properties() self._compile_pks() @@ -281,12 +281,19 @@ class Mapper(object): for ext_obj in util.to_list(extension): # local MapperExtensions have already instrumented the class extlist.add(ext_obj) - - for ext in global_extensions: - if isinstance(ext, type): - ext = ext() - extlist.add(ext) - ext.instrument_class(self, self.class_) + + if self.inherits is not None: + for ext in self.inherits.extension: + if ext not in extlist: + extlist.add(ext) + ext.instrument_class(self, self.class_) + else: + for ext in global_extensions: + if isinstance(ext, type): + ext = ext() + if ext not in extlist: + extlist.add(ext) + ext.instrument_class(self, self.class_) self.extension = ExtensionCarrier() for ext in extlist: @@ -960,14 +967,13 @@ class Mapper(object): if not postupdate: # call before_XXX extensions for state, connection, has_identity in tups: + mapper = _state_mapper(state) if not has_identity: - for mapper in _state_mapper(state).iterate_to_root(): - if 'before_insert' in mapper.extension.methods: - mapper.extension.before_insert(mapper, connection, state.obj()) + if 'before_insert' in mapper.extension.methods: + mapper.extension.before_insert(mapper, connection, state.obj()) else: - for mapper in _state_mapper(state).iterate_to_root(): - if 'before_update' in mapper.extension.methods: - mapper.extension.before_update(mapper, connection, state.obj()) + if 'before_update' in mapper.extension.methods: + mapper.extension.before_update(mapper, connection, state.obj()) for state, connection, has_identity in tups: # detect if we have a "pending" instance (i.e. has no instance_key attached to it), @@ -1131,13 +1137,14 @@ class Mapper(object): if not postupdate: # call after_XXX extensions for state, connection in inserted_objects: - for mapper in _state_mapper(state).iterate_to_root(): - if 'after_insert' in mapper.extension.methods: - mapper.extension.after_insert(mapper, connection, state.obj()) + mapper = _state_mapper(state) + if 'after_insert' in mapper.extension.methods: + mapper.extension.after_insert(mapper, connection, state.obj()) + for state, connection in updated_objects: - for mapper in _state_mapper(state).iterate_to_root(): - if 'after_update' in mapper.extension.methods: - mapper.extension.after_update(mapper, connection, state.obj()) + mapper = _state_mapper(state) + if 'after_update' in mapper.extension.methods: + mapper.extension.after_update(mapper, connection, state.obj()) def _postfetch(self, connection, table, state, resultproxy, params, value_params): """After an ``INSERT`` or ``UPDATE``, assemble newly generated @@ -1177,9 +1184,9 @@ class Mapper(object): tups = [(state, connection) for state in states] for (state, connection) in tups: - for mapper in _state_mapper(state).iterate_to_root(): - if 'before_delete' in mapper.extension.methods: - mapper.extension.before_delete(mapper, connection, state.obj()) + mapper = _state_mapper(state) + if 'before_delete' in mapper.extension.methods: + mapper.extension.before_delete(mapper, connection, state.obj()) deleted_objects = util.Set() table_to_mapper = {} @@ -1225,9 +1232,9 @@ class Mapper(object): raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects))) for state, connection in deleted_objects: - for mapper in _state_mapper(state).iterate_to_root(): - if 'after_delete' in mapper.extension.methods: - mapper.extension.after_delete(mapper, connection, state.obj()) + mapper = _state_mapper(state) + if 'after_delete' in mapper.extension.methods: + mapper.extension.after_delete(mapper, connection, state.obj()) def _register_dependencies(self, uowcommit): """Register ``DependencyProcessor`` instances with a diff --git a/test/orm/mapper.py b/test/orm/mapper.py index a1ca19a64f..662ac4a29f 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -1131,83 +1131,73 @@ class NoLoadTest(MapperSuperTest): {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])}, ) -class MapperExtensionTest(MapperSuperTest): +class MapperExtensionTest(PersistTest): def setUpAll(self): tables.create() - def tearDownAll(self): - tables.drop() - def tearDown(self): - clear_mappers() - tables.delete() - def setUp(self): - tables.data() - - def test_create_instance(self): - class Ext(MapperExtension): - def create_instance(self, *args, **kwargs): - return User() - m = mapper(Address, addresses) - m = mapper(User, users, extension=Ext(), properties = dict( - addresses = relation(Address, lazy=True), - )) - - q = create_session().query(m) - l = q.select(); - self.assert_result(l, User, *user_address_result) - - def test_methods(self): - """test that common user-defined methods get called.""" - - methods = set() + + global methods, Ext + + methods = [] + class Ext(MapperExtension): def load(self, query, *args, **kwargs): - methods.add('load') + methods.append('load') return EXT_CONTINUE def get(self, query, *args, **kwargs): - methods.add('get') + methods.append('get') return EXT_CONTINUE def translate_row(self, mapper, context, row): - methods.add('translate_row') + methods.append('translate_row') return EXT_CONTINUE def create_instance(self, mapper, selectcontext, row, class_): - methods.add('create_instance') + methods.append('create_instance') return EXT_CONTINUE def append_result(self, mapper, selectcontext, row, instance, result, **flags): - methods.add('append_result') + methods.append('append_result') return EXT_CONTINUE def populate_instance(self, mapper, selectcontext, row, instance, **flags): - methods.add('populate_instance') + methods.append('populate_instance') return EXT_CONTINUE def before_insert(self, mapper, connection, instance): - methods.add('before_insert') + methods.append('before_insert') return EXT_CONTINUE def after_insert(self, mapper, connection, instance): - methods.add('after_insert') + methods.append('after_insert') return EXT_CONTINUE def before_update(self, mapper, connection, instance): - methods.add('before_update') + methods.append('before_update') return EXT_CONTINUE def after_update(self, mapper, connection, instance): - methods.add('after_update') + methods.append('after_update') return EXT_CONTINUE def before_delete(self, mapper, connection, instance): - methods.add('before_delete') + methods.append('before_delete') return EXT_CONTINUE def after_delete(self, mapper, connection, instance): - methods.add('after_delete') + methods.append('after_delete') return EXT_CONTINUE + def tearDown(self): + clear_mappers() + methods[:] = [] + tables.delete() + + def tearDownAll(self): + tables.drop() + + def test_basic(self): + """test that common user-defined methods get called.""" mapper(User, users, extension=Ext()) sess = create_session() u = User() @@ -1220,10 +1210,54 @@ class MapperExtensionTest(MapperSuperTest): sess.flush() sess.delete(u) sess.flush() - self.assertEquals(methods, set(['load', 'before_delete', 'create_instance', 'translate_row', 'get', - 'after_delete', 'after_insert', 'before_update', 'before_insert', 'after_update', 'populate_instance'])) + self.assertEquals(methods, ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'get', + 'translate_row', 'create_instance', 'populate_instance', 'before_update', 'after_update', 'before_delete', 'after_delete']) + def test_inheritance(self): + # test using inheritance + class AdminUser(User): + pass + + mapper(User, users, extension=Ext()) + mapper(AdminUser, addresses, inherits=User) + + sess = create_session() + am = AdminUser() + sess.save(am) + sess.flush() + am = sess.query(AdminUser).load(am.user_id) + sess.clear() + am = sess.query(AdminUser).get(am.user_id) + am.user_name = 'foobar' + sess.flush() + sess.delete(am) + sess.flush() + self.assertEquals(methods, ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'get', + 'translate_row', 'create_instance', 'populate_instance', 'before_update', 'after_update', 'before_delete', 'after_delete']) + + def test_inheritance_with_dupes(self): + # test using inheritance, same extension on both mappers + class AdminUser(User): + pass + ext = Ext() + mapper(User, users, extension=ext) + mapper(AdminUser, addresses, inherits=User, extension=ext) + + sess = create_session() + am = AdminUser() + sess.save(am) + sess.flush() + am = sess.query(AdminUser).load(am.user_id) + sess.clear() + am = sess.query(AdminUser).get(am.user_id) + am.user_name = 'foobar' + sess.flush() + sess.delete(am) + sess.flush() + self.assertEquals(methods, ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'get', + 'translate_row', 'create_instance', 'populate_instance', 'before_update', 'after_update', 'before_delete', 'after_delete']) + class RequirementsTest(AssertMixin): """Tests the contract for user classes."""