From: Mike Bayer Date: Sat, 6 Nov 2010 20:43:48 +0000 (-0400) Subject: - most tests passing on adapted MapperExtension X-Git-Tag: rel_0_7b1~253^2~15 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a87c8f7df266639123990e7e2f4056a257739833;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - most tests passing on adapted MapperExtension --- diff --git a/lib/sqlalchemy/event.py b/lib/sqlalchemy/event.py index 0f6342e6b2..379c3f1dd0 100644 --- a/lib/sqlalchemy/event.py +++ b/lib/sqlalchemy/event.py @@ -68,11 +68,12 @@ class _Dispatch(object): object.""" for ls in other.descriptors: - existing_listeners = getattr(self, ls.name).listeners - existing_listener_set = set(existing_listeners) - existing_listeners.extend([l for l - in ls.listeners - if l not in existing_listener_set]) + getattr(self, ls.name).update(ls) + #existing_listeners = getattr(self, ls.name).listeners + #existing_listener_set = set(existing_listeners) + #existing_listeners.extend([l for l + # in ls.listeners + # if l not in existing_listener_set]) class _EventMeta(type): """Intercept new Event subclasses and create @@ -198,7 +199,17 @@ class _ListenerCollection(object): def __nonzero__(self): return bool(self.listeners or self.parent_listeners) - + + def update(self, other): + """Populate from the listeners in another :class:`_Dispatch` + object.""" + + existing_listeners = self.listeners + existing_listener_set = set(existing_listeners) + existing_listeners.extend([l for l + in other.listeners + if l not in existing_listener_set]) + def append(self, obj, target): if obj not in self.listeners: self.listeners.append(obj) diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 1c5e630cf4..6df13e94f0 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -82,7 +82,6 @@ __all__ = ( 'dynamic_loader', 'eagerload', 'eagerload_all', - 'extension', 'immediateload', 'join', 'joinedload', diff --git a/lib/sqlalchemy/orm/deprecated_interfaces.py b/lib/sqlalchemy/orm/deprecated_interfaces.py index 2145bef4b9..3817dc2eee 100644 --- a/lib/sqlalchemy/orm/deprecated_interfaces.py +++ b/lib/sqlalchemy/orm/deprecated_interfaces.py @@ -1,4 +1,4 @@ -from sqlalchemy import event +from sqlalchemy import event, util from interfaces import EXT_CONTINUE @@ -56,20 +56,44 @@ class MapperExtension(object): @classmethod def _adapt_listener_methods(cls, self, listener, methods): + for meth in methods: me_meth = getattr(MapperExtension, meth) ls_meth = getattr(listener, meth) + # TODO: comparing self.methods to cls.method, # this comparison is probably moot + if me_meth is not ls_meth: if meth == 'reconstruct_instance': def go(ls_meth): def reconstruct(instance): ls_meth(self, instance) return reconstruct - event.listen(go(ls_meth), 'on_load', self.class_manager, raw=False) + event.listen(go(ls_meth), 'on_load', + self.class_manager, raw=False) + elif meth == 'init_instance': + def go(ls_meth): + def init_instance(instance, args, kwargs): + ls_meth(self, self.class_, + self.class_manager.original_init, + instance, args, kwargs) + return init_instance + event.listen(go(ls_meth), 'on_init', + self.class_manager, raw=False) + elif meth == 'init_failed': + def go(ls_meth): + def init_failed(instance, args, kwargs): + util.warn_exception(ls_meth, self, self.class_, + self.class_manager.original_init, + instance, args, kwargs) + + return init_failed + event.listen(go(ls_meth), 'on_init_failure', + self.class_manager, raw=False) else: - event.listen(ls_meth, "on_%s" % meth, self, raw=False, retval=True) + event.listen(ls_meth, "on_%s" % meth, self, + raw=False, retval=True) def instrument_class(self, mapper, class_): diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index c3eab67e14..5765b9ea8b 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -83,13 +83,39 @@ class InstanceEvents(event.Events): raise NotImplementedError("Removal of instance events not yet implemented") def on_init(self, target, args, kwargs): - """""" + """Receive an instance when it's constructor is called. + + This method is only called during a userland construction of + an object. It is not called when an object is loaded from the + database. + + """ def on_init_failure(self, target, args, kwargs): - """""" + """Receive an instance when it's constructor has been called, + and raised an exception. + + This method is only called during a userland construction of + an object. It is not called when an object is loaded from the + database. + + """ def on_load(self, target): - """""" + """Receive an object instance after it has been created via + ``__new__``, and after initial attribute population has + occurred. + + This typically occurs when the instance is created based on + incoming result rows, and is only called once for that + instance's lifetime. + + Note that during a result-row load, this method is called upon + the first row received for this instance. Note that some + attributes and collections may or may not be loaded or even + initialized, depending on what's present in the result rows. + + """ def on_resurrect(self, target): """""" @@ -180,31 +206,6 @@ class MapperEvents(event.Events): """ - def on_init_instance(self, mapper, class_, oldinit, target, args, kwargs): - """Receive an instance when it's constructor is called. - - This method is only called during a userland construction of - an object. It is not called when an object is loaded from the - database. - - The return value is only significant within the ``MapperExtension`` - chain; the parent mapper's behavior isn't modified by this method. - - """ - - def on_init_failed(self, mapper, class_, oldinit, target, args, kwargs): - """Receive an instance when it's constructor has been called, - and raised an exception. - - This method is only called during a userland construction of - an object. It is not called when an object is loaded from the - database. - - The return value is only significant within the ``MapperExtension`` - chain; the parent mapper's behavior isn't modified by this method. - - """ - def on_translate_row(self, mapper, context, row): """Perform pre-processing on the given result row and return a new row instance. @@ -306,25 +307,6 @@ class MapperEvents(event.Events): """ - def on_reconstruct_instance(self, mapper, target): - """Receive an object instance after it has been created via - ``__new__``, and after initial attribute population has - occurred. - - This typically occurs when the instance is created based on - incoming result rows, and is only called once for that - instance's lifetime. - - Note that during a result-row load, this method is called upon - the first row received for this instance. Note that some - attributes and collections may or may not be loaded or even - initialized, depending on what's present in the result rows. - - The return value is only significant within the ``MapperExtension`` - chain; the parent mapper's behavior isn't modified by this method. - - """ - def on_before_insert(self, mapper, connection, target): """Receive an object instance before that instance is inserted into its table. @@ -341,7 +323,6 @@ class MapperEvents(event.Events): """ - def on_after_insert(self, mapper, connection, target): """Receive an object instance after that instance is inserted. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 643ce6faa5..03e313685e 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -141,8 +141,8 @@ class Mapper(object): self._inherits_equated_pairs = None self._memoized_values = {} self._compiled_cache_size = _compiled_cache_size - - self._deprecated_extensions = extension + self._reconstructor = None + self._deprecated_extensions = util.to_list(extension or []) if allow_null_pks: util.warn_deprecated( @@ -322,19 +322,32 @@ class Mapper(object): % self) def _configure_legacy_instrument_class(self): - # TODO: tests failing - for ext in util.to_list(self._deprecated_extensions or []): - ext._adapt_instrument_class(self, ext) + + if self.inherits: + self.dispatch.update(self.inherits.dispatch) + super_extensions = set(chain(*[m._deprecated_extensions + for m in self.inherits.iterate_to_root()])) + else: + super_extensions = set() + + for ext in self._deprecated_extensions: + if ext not in super_extensions: + ext._adapt_instrument_class(self, ext) def _configure_listeners(self): - # TODO: this has to be made smarter to look - # for existing extensions - - for ext in util.to_list(self._deprecated_extensions or []): - ext._adapt_listener(self, ext) + if self.inherits: + super_extensions = set(chain(*[m._deprecated_extensions + for m in self.inherits.iterate_to_root()])) + else: + super_extensions = set() + + for ext in self._deprecated_extensions: + if ext not in super_extensions: + ext._adapt_listener(self, ext) if self.inherits: - self.dispatch.update(self.inherits.dispatch) + self.class_manager.dispatch.update( + self.inherits.class_manager.dispatch) def _configure_class_instrumentation(self): """If this mapper is to be a primary mapper (i.e. the @@ -398,7 +411,8 @@ class Mapper(object): for key, method in util.iterate_attributes(self.class_): if isinstance(method, types.FunctionType): if hasattr(method, '__sa_reconstructor__'): - event.listen(method, 'on_load', manager, raw=True) + self._reconstructor = method + event.listen(_event_on_load, 'on_load', manager, raw=True) elif hasattr(method, '__sa_validators__'): for name in method.__sa_validators__: self._validators[name] = method @@ -2267,17 +2281,22 @@ class Mapper(object): attrs = state.unloaded # allow query.instances to commit the subset of attrs context.partials[state] = (dict_, attrs) - - if not populate_instance or \ - populate_instance(self, context, row, instance, + + if populate_instance: + for fn in populate_instance: + ret = fn(self, context, row, state, only_load_props=attrs, - instancekey=identitykey, isnew=isnew) is \ - EXT_CONTINUE: + instancekey=identitykey, isnew=isnew) + if ret is not EXT_CONTINUE: + break + else: + populate_state(state, dict_, row, isnew, attrs) + else: populate_state(state, dict_, row, isnew, attrs) if loaded_instance: - state._run_on_load() - + state.manager.dispatch.on_load(state) + if result is not None: if append_result: for fn in append_result: @@ -2382,25 +2401,22 @@ def validates(*names): return fn return wrap +def _event_on_load(state): + instrumenting_mapper = state.manager.info[_INSTRUMENTOR] + if instrumenting_mapper._reconstructor: + instrumenting_mapper._reconstructor(state.obj()) + def _event_on_init(state, args, kwargs): """Trigger mapper compilation and run init_instance hooks.""" instrumenting_mapper = state.manager.info[_INSTRUMENTOR] # compile() always compiles all mappers instrumenting_mapper.compile() - instrumenting_mapper.dispatch.on_init_instance( - instrumenting_mapper, instrumenting_mapper.class_, - state.manager.original_init, - state, args, kwargs) def _event_on_init_failure(state, args, kwargs): """Run init_failed hooks.""" instrumenting_mapper = state.manager.info[_INSTRUMENTOR] - util.warn_exception( - instrumenting_mapper.dispatch.on_init_failed, - instrumenting_mapper, instrumenting_mapper.class_, - state.manager.original_init, state, args, kwargs) def _event_on_resurrect(state): # re-populate the primary key elements diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 5e1c7ba094..710d3213a4 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1241,7 +1241,7 @@ class Session(object): merged_state.commit_all(merged_dict, self.identity_map) if new_instance: - merged_state._run_on_load() + merged_state.manager.dispatch.on_load(merged_state) return merged @classmethod diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index bea4ee500a..3e977a4c9c 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -142,9 +142,6 @@ class InstanceState(object): else: return [x] - def _run_on_load(self): - self.manager.dispatch.on_load(self) - def __getstate__(self): d = {'instance':self.obj()} diff --git a/test/orm/test_generative.py b/test/orm/test_generative.py index 141fde9fc6..06c07dc629 100644 --- a/test/orm/test_generative.py +++ b/test/orm/test_generative.py @@ -122,15 +122,6 @@ class GenerativeQueryTest(_base.MappedTest): res2 = query.filter(Foo.bar < 30).filter(Foo.bar > 10) assert res2.count() == 19 - @testing.resolve_artifact_names - def test_options(self): - query = create_session().query(Foo) - class ext1(sa.orm.MapperExtension): - def populate_instance(self, mapper, selectcontext, row, instance, **flags): - instance.TEST = "hello world" - return sa.orm.EXT_CONTINUE - assert query.options(sa.orm.extension(ext1()))[0].TEST == "hello world" - @testing.resolve_artifact_names def test_order_by(self): query = create_session().query(Foo) diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index 05cf1fd31d..b6432a39aa 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -194,6 +194,7 @@ class MapperTest(_fixtures.FixtureTest): @testing.resolve_artifact_names def test_constructor_exc_1(self): """Exceptions raised in the mapped class are not masked by sa decorations""" + ex = AssertionError('oops') sess = create_session() @@ -268,27 +269,6 @@ class MapperTest(_fixtures.FixtureTest): mapper(Foo, addresses, inherits=User) assert getattr(Foo().__class__, 'name').impl is not None - @testing.resolve_artifact_names - def test_extension_collection_frozen(self): - class Foo(User):pass - m = mapper(User, users) - mapper(Order, orders) - compile_mappers() - mapper(Foo, addresses, inherits=User) - ext_list = [AttributeExtension()] - m.add_property('somename', column_property(users.c.name, extension=ext_list)) - m.add_property('orders', relationship(Order, extension=ext_list, backref='user')) - assert len(ext_list) == 1 - - assert Foo.orders.impl.extensions is User.orders.impl.extensions - assert Foo.orders.impl.extensions is not ext_list - - compile_mappers() - assert len(User.somename.impl.extensions) == 1 - assert len(Foo.somename.impl.extensions) == 1 - assert len(Foo.orders.impl.extensions) == 3 - assert len(User.orders.impl.extensions) == 3 - @testing.resolve_artifact_names def test_compile_on_get_props_1(self): @@ -1073,16 +1053,19 @@ class MapperTest(_fixtures.FixtureTest): class A(object): @reconstructor def reconstruct(self): + assert isinstance(self, A) recon.append('A') class B(A): @reconstructor def reconstruct(self): + assert isinstance(self, B) recon.append('B') class C(A): @reconstructor def reconstruct(self): + assert isinstance(self, C) recon.append('C') mapper(A, users, polymorphic_on=users.c.name,