issued directly by the ORM in the form of UPDATE statements, by setting
the flag "passive_cascades=False".
+ - inheriting mappers now inherit the MapperExtensions of their parent
+ mapper directly, so that all methods for a particular MapperExtension
+ are called for subclasses as well. As always, any MapperExtension
+ can return either EXT_CONTINUE to continue extension processing
+ or EXT_STOP to stop processing. The order of mapper resolution is:
+ <extensions declared on the classes mapper> <extensions declared on the
+ classes' parent mapper> <globally declared extensions>.
+
+ Note that if you instantiate the same extension class separately
+ and then apply it individually for two mappers in the same inheritance
+ chain, the extension will be applied twice to the inheriting class,
+ and each method will be called twice.
+
+ To apply a mapper extension explicitly to each inheriting class but
+ have each method called only once per operation, use the same
+ instance of the extension for both mappers.
+ [ticket:490]
+
- new synonym() behavior: an attribute will be placed on the mapped
class, if one does not exist already, in all cases. if a property
already exists on the class, the synonym will decorate the property
self.__should_log_debug = logging.is_debug_enabled(self.logger)
self._compile_class()
- self._compile_extensions()
self._compile_inheritance()
+ self._compile_extensions()
self._compile_tables()
self._compile_properties()
self._compile_pks()
for ext_obj in util.to_list(extension):
# local MapperExtensions have already instrumented the class
extlist.add(ext_obj)
-
- for ext in global_extensions:
- if isinstance(ext, type):
- ext = ext()
- extlist.add(ext)
- ext.instrument_class(self, self.class_)
+
+ if self.inherits is not None:
+ for ext in self.inherits.extension:
+ if ext not in extlist:
+ extlist.add(ext)
+ ext.instrument_class(self, self.class_)
+ else:
+ for ext in global_extensions:
+ if isinstance(ext, type):
+ ext = ext()
+ if ext not in extlist:
+ extlist.add(ext)
+ ext.instrument_class(self, self.class_)
self.extension = ExtensionCarrier()
for ext in extlist:
if not postupdate:
# call before_XXX extensions
for state, connection, has_identity in tups:
+ mapper = _state_mapper(state)
if not has_identity:
- for mapper in _state_mapper(state).iterate_to_root():
- if 'before_insert' in mapper.extension.methods:
- mapper.extension.before_insert(mapper, connection, state.obj())
+ if 'before_insert' in mapper.extension.methods:
+ mapper.extension.before_insert(mapper, connection, state.obj())
else:
- for mapper in _state_mapper(state).iterate_to_root():
- if 'before_update' in mapper.extension.methods:
- mapper.extension.before_update(mapper, connection, state.obj())
+ if 'before_update' in mapper.extension.methods:
+ mapper.extension.before_update(mapper, connection, state.obj())
for state, connection, has_identity in tups:
# detect if we have a "pending" instance (i.e. has no instance_key attached to it),
if not postupdate:
# call after_XXX extensions
for state, connection in inserted_objects:
- for mapper in _state_mapper(state).iterate_to_root():
- if 'after_insert' in mapper.extension.methods:
- mapper.extension.after_insert(mapper, connection, state.obj())
+ mapper = _state_mapper(state)
+ if 'after_insert' in mapper.extension.methods:
+ mapper.extension.after_insert(mapper, connection, state.obj())
+
for state, connection in updated_objects:
- for mapper in _state_mapper(state).iterate_to_root():
- if 'after_update' in mapper.extension.methods:
- mapper.extension.after_update(mapper, connection, state.obj())
+ mapper = _state_mapper(state)
+ if 'after_update' in mapper.extension.methods:
+ mapper.extension.after_update(mapper, connection, state.obj())
def _postfetch(self, connection, table, state, resultproxy, params, value_params):
"""After an ``INSERT`` or ``UPDATE``, assemble newly generated
tups = [(state, connection) for state in states]
for (state, connection) in tups:
- for mapper in _state_mapper(state).iterate_to_root():
- if 'before_delete' in mapper.extension.methods:
- mapper.extension.before_delete(mapper, connection, state.obj())
+ mapper = _state_mapper(state)
+ if 'before_delete' in mapper.extension.methods:
+ mapper.extension.before_delete(mapper, connection, state.obj())
deleted_objects = util.Set()
table_to_mapper = {}
raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects)))
for state, connection in deleted_objects:
- for mapper in _state_mapper(state).iterate_to_root():
- if 'after_delete' in mapper.extension.methods:
- mapper.extension.after_delete(mapper, connection, state.obj())
+ mapper = _state_mapper(state)
+ if 'after_delete' in mapper.extension.methods:
+ mapper.extension.after_delete(mapper, connection, state.obj())
def _register_dependencies(self, uowcommit):
"""Register ``DependencyProcessor`` instances with a
{'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])},
)
-class MapperExtensionTest(MapperSuperTest):
+class MapperExtensionTest(PersistTest):
def setUpAll(self):
tables.create()
- def tearDownAll(self):
- tables.drop()
- def tearDown(self):
- clear_mappers()
- tables.delete()
- def setUp(self):
- tables.data()
-
- def test_create_instance(self):
- class Ext(MapperExtension):
- def create_instance(self, *args, **kwargs):
- return User()
- m = mapper(Address, addresses)
- m = mapper(User, users, extension=Ext(), properties = dict(
- addresses = relation(Address, lazy=True),
- ))
-
- q = create_session().query(m)
- l = q.select();
- self.assert_result(l, User, *user_address_result)
-
- def test_methods(self):
- """test that common user-defined methods get called."""
-
- methods = set()
+
+ global methods, Ext
+
+ methods = []
+
class Ext(MapperExtension):
def load(self, query, *args, **kwargs):
- methods.add('load')
+ methods.append('load')
return EXT_CONTINUE
def get(self, query, *args, **kwargs):
- methods.add('get')
+ methods.append('get')
return EXT_CONTINUE
def translate_row(self, mapper, context, row):
- methods.add('translate_row')
+ methods.append('translate_row')
return EXT_CONTINUE
def create_instance(self, mapper, selectcontext, row, class_):
- methods.add('create_instance')
+ methods.append('create_instance')
return EXT_CONTINUE
def append_result(self, mapper, selectcontext, row, instance, result, **flags):
- methods.add('append_result')
+ methods.append('append_result')
return EXT_CONTINUE
def populate_instance(self, mapper, selectcontext, row, instance, **flags):
- methods.add('populate_instance')
+ methods.append('populate_instance')
return EXT_CONTINUE
def before_insert(self, mapper, connection, instance):
- methods.add('before_insert')
+ methods.append('before_insert')
return EXT_CONTINUE
def after_insert(self, mapper, connection, instance):
- methods.add('after_insert')
+ methods.append('after_insert')
return EXT_CONTINUE
def before_update(self, mapper, connection, instance):
- methods.add('before_update')
+ methods.append('before_update')
return EXT_CONTINUE
def after_update(self, mapper, connection, instance):
- methods.add('after_update')
+ methods.append('after_update')
return EXT_CONTINUE
def before_delete(self, mapper, connection, instance):
- methods.add('before_delete')
+ methods.append('before_delete')
return EXT_CONTINUE
def after_delete(self, mapper, connection, instance):
- methods.add('after_delete')
+ methods.append('after_delete')
return EXT_CONTINUE
+ def tearDown(self):
+ clear_mappers()
+ methods[:] = []
+ tables.delete()
+
+ def tearDownAll(self):
+ tables.drop()
+
+ def test_basic(self):
+ """test that common user-defined methods get called."""
mapper(User, users, extension=Ext())
sess = create_session()
u = User()
sess.flush()
sess.delete(u)
sess.flush()
- self.assertEquals(methods, set(['load', 'before_delete', 'create_instance', 'translate_row', 'get',
- 'after_delete', 'after_insert', 'before_update', 'before_insert', 'after_update', 'populate_instance']))
+ self.assertEquals(methods, ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'get',
+ 'translate_row', 'create_instance', 'populate_instance', 'before_update', 'after_update', 'before_delete', 'after_delete'])
+ def test_inheritance(self):
+ # test using inheritance
+ class AdminUser(User):
+ pass
+
+ mapper(User, users, extension=Ext())
+ mapper(AdminUser, addresses, inherits=User)
+
+ sess = create_session()
+ am = AdminUser()
+ sess.save(am)
+ sess.flush()
+ am = sess.query(AdminUser).load(am.user_id)
+ sess.clear()
+ am = sess.query(AdminUser).get(am.user_id)
+ am.user_name = 'foobar'
+ sess.flush()
+ sess.delete(am)
+ sess.flush()
+ self.assertEquals(methods, ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'get',
+ 'translate_row', 'create_instance', 'populate_instance', 'before_update', 'after_update', 'before_delete', 'after_delete'])
+
+ def test_inheritance_with_dupes(self):
+ # test using inheritance, same extension on both mappers
+ class AdminUser(User):
+ pass
+ ext = Ext()
+ mapper(User, users, extension=ext)
+ mapper(AdminUser, addresses, inherits=User, extension=ext)
+
+ sess = create_session()
+ am = AdminUser()
+ sess.save(am)
+ sess.flush()
+ am = sess.query(AdminUser).load(am.user_id)
+ sess.clear()
+ am = sess.query(AdminUser).get(am.user_id)
+ am.user_name = 'foobar'
+ sess.flush()
+ sess.delete(am)
+ sess.flush()
+ self.assertEquals(methods, ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'get',
+ 'translate_row', 'create_instance', 'populate_instance', 'before_update', 'after_update', 'before_delete', 'after_delete'])
+
class RequirementsTest(AssertMixin):
"""Tests the contract for user classes."""