From 4a02ef1c091d6ab0ee0cac1f2e42e697d62b5f98 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 18 Jul 2006 15:14:55 +0000 Subject: [PATCH] overhaul to MapperExtension so they arent chained via "next"; this breaks all over the place since extensions get copied between mappers etc. now theyre assembled into a list, of which a single extension can belong to many different lists. --- CHANGES | 2 + lib/sqlalchemy/orm/mapper.py | 115 +++++++++++++++++------------------ 2 files changed, 59 insertions(+), 58 deletions(-) diff --git a/CHANGES b/CHANGES index f27c000ccc..326a8f3c34 100644 --- a/CHANGES +++ b/CHANGES @@ -6,6 +6,8 @@ Existing methods of primary/foreign key creation have not been changed but use these new objects behind the scenes. table creation and reflection is now more table oriented rather than column oriented. [ticket:76] +- overhaul to MapperExtension calling scheme, wasnt working very well +previously - tweaks to ActiveMapper, supports self-referential relationships - slight rearrangement to objectstore (in activemapper/threadlocal) so that the SessionContext is referenced by '.context' instead diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 253410dabf..0eee8f26b3 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -186,6 +186,7 @@ class Mapper(object): # uber-pendantic style of making mapper chain, as various testbase/ # threadlocal/assignmapper combinations keep putting dupes etc. in the list # TODO: do something that isnt 21 lines.... + extlist = util.Set() for ext_class in global_extensions: if isinstance(ext_class, MapperExtension): @@ -198,18 +199,12 @@ class Mapper(object): for ext_obj in util.to_list(extension): extlist.add(ext_obj) - self.extension = None - previous = None + self.extension = ExtensionCarrier() for ext in extlist: - ext.unchain() - if self.extension is None: - self.extension = ext - if previous is not None: - previous.chain(ext) - previous = ext - if self.extension is None: - self.extension = MapperExtension() - + self.extension.elements.append(ext) + + print "EXTENSIONS COMPILED", self.class_, self.extension.elements + def _compile_inheritance(self): """determines if this Mapper inherits from another mapper, and if so calculates the mapped_table for this Mapper taking the inherited mapper into account. for joined table inheritance, creates @@ -714,10 +709,6 @@ class Mapper(object): # already inserted/updated this row but I need you to UPDATE one more # time" isinsert = not postupdate and not hasattr(obj, "_instance_key") - if isinsert: - self.extension.before_insert(self, connection, obj) - else: - self.extension.before_update(self, connection, obj) hasdata = False for col in table.columns: if col is self.version_id_col: @@ -1160,33 +1151,17 @@ class MapperExtension(object): """base implementation for an object that provides overriding behavior to various Mapper functions. For each method in MapperExtension, a result of EXT_PASS indicates the functionality is not overridden.""" - def __init__(self): - self.next = None - def chain(self, ext): - self.next = ext - return self - def unchain(self): - self.next = None def get_session(self): """called to retrieve a contextual Session instance with which to register a new object. Note: this is not called if a session is provided with the __init__ params (i.e. _sa_session)""" - if self.next is None: - return EXT_PASS - else: - return self.next.get_session() + return EXT_PASS def select_by(self, query, *args, **kwargs): """overrides the select_by method of the Query object""" - if self.next is None: - return EXT_PASS - else: - return self.next.select_by(query, *args, **kwargs) + return EXT_PASS def select(self, query, *args, **kwargs): """overrides the select method of the Query object""" - if self.next is None: - return EXT_PASS - else: - return self.next.select(query, *args, **kwargs) + return EXT_PASS def create_instance(self, mapper, session, row, imap, class_): """called when a new object instance is about to be created from a row. the method can choose to create the instance itself, or it can return @@ -1201,10 +1176,7 @@ class MapperExtension(object): class_ - the class we are mapping. """ - if self.next is None: - return EXT_PASS - else: - return self.next.create_instance(mapper, session, row, imap, class_) + return EXT_PASS def append_result(self, mapper, session, row, imap, result, instance, isnew, populate_existing=False): """called when an object instance is being appended to a result list. @@ -1230,10 +1202,7 @@ class MapperExtension(object): populate_existing - usually False, indicates if object instances that were already in the main identity map, i.e. were loaded by a previous select(), get their attributes overwritten """ - if self.next is None: - return EXT_PASS - else: - return self.next.append_result(mapper, session, row, imap, result, instance, isnew, populate_existing) + return EXT_PASS def populate_instance(self, mapper, session, instance, row, identitykey, imap, isnew): """called right before the mapper, after creating an instance from a row, passes the row to its MapperProperty objects which are responsible for populating the object's attributes. @@ -1246,37 +1215,67 @@ class MapperExtension(object): othermapper.populate_instance(session, instance, row, identitykey, imap, isnew, frommapper=mapper) return True """ - if self.next is None: - return EXT_PASS - else: - return self.next.populate_instance(mapper, session, instance, row, identitykey, imap, isnew) + return EXT_PASS def before_insert(self, mapper, connection, instance): """called before an object instance is INSERTed into its table. this is a good place to set up primary key values and such that arent handled otherwise.""" - if self.next is not None: - self.next.before_insert(mapper, connection, instance) + return EXT_PASS def before_update(self, mapper, connection, instance): """called before an object instnace is UPDATED""" - if self.next is not None: - self.next.before_update(mapper, connection, instance) + return EXT_PASS def after_update(self, mapper, connection, instance): """called after an object instnace is UPDATED""" - if self.next is not None: - self.next.after_update(mapper, connection, instance) + return EXT_PASS def after_insert(self, mapper, connection, instance): """called after an object instance has been INSERTed""" - if self.next is not None: - self.next.after_insert(mapper, connection, instance) + return EXT_PASS def before_delete(self, mapper, connection, instance): """called before an object instance is DELETEed""" - if self.next is not None: - self.next.before_delete(mapper, connection, instance) + return EXT_PASS def after_delete(self, mapper, connection, instance): """called after an object instance is DELETEed""" - if self.next is not None: - self.next.after_delete(mapper, connection, instance) + return EXT_PASS +class ExtensionCarrier(MapperExtension): + def __init__(self): + self.elements = [] + # TODO: shrink down this approach using __getattribute__ or similar + def get_session(self): + return self._do('get_session') + def select_by(self, *args, **kwargs): + return self._do('select_by', *args, **kwargs) + def select(self, *args, **kwargs): + return self._do('select', *args, **kwargs) + def create_instance(self, *args, **kwargs): + return self._do('create_instance', *args, **kwargs) + def append_result(self, *args, **kwargs): + return self._do('append_result', *args, **kwargs) + def populate_instance(self, *args, **kwargs): + return self._do('populate_instance', *args, **kwargs) + def before_insert(self, *args, **kwargs): + return self._do('before_insert', *args, **kwargs) + def before_update(self, *args, **kwargs): + return self._do('before_update', *args, **kwargs) + def after_update(self, *args, **kwargs): + return self._do('after_update', *args, **kwargs) + def after_insert(self, *args, **kwargs): + return self._do('after_insert', *args, **kwargs) + def before_delete(self, *args, **kwargs): + return self._do('before_delete', *args, **kwargs) + def after_delete(self, *args, **kwargs): + return self._do('after_delete', *args, **kwargs) + + def _do(self, funcname, *args, **kwargs): + for elem in self.elements: + if elem is self: + raise "WTF" + ret = getattr(elem, funcname)(*args, **kwargs) + if ret is not EXT_PASS: + return ret + else: + return EXT_PASS + class TranslatingDict(dict): """a dictionary that stores ColumnElement objects as keys. incoming ColumnElement keys are translated against those of an underling FromClause for all operations. -- 2.47.2