From: Mike Bayer Date: Fri, 31 Dec 2010 16:46:30 +0000 (-0500) Subject: - add QueryContext to load(), refresh() X-Git-Tag: rel_0_7b1~100 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9d04eaffcc600befce98117eabe82e0ead0dc741;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - add QueryContext to load(), refresh() - add list of attribute names to refresh() - ensure refresh() only called when attributes actually refreshed - tests. [ticket:2011] --- diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index f3bd91efb6..e849dfcf38 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -51,7 +51,7 @@ class Mutable(object): key = attribute.key parent_cls = attribute.class_ - def load(state): + def load(state, *args): """Listen for objects loaded or refreshed. Wrap the target data member's value with @@ -230,7 +230,7 @@ class MutableComposite(object): key = attribute.key parent_cls = attribute.class_ - def load(state): + def load(state, *args): """Listen for objects loaded or refreshed. Wrap the target data member's value with diff --git a/lib/sqlalchemy/orm/deprecated_interfaces.py b/lib/sqlalchemy/orm/deprecated_interfaces.py index b294a8d7d6..26f1645093 100644 --- a/lib/sqlalchemy/orm/deprecated_interfaces.py +++ b/lib/sqlalchemy/orm/deprecated_interfaces.py @@ -83,7 +83,7 @@ class MapperExtension(object): if me_meth is not ls_meth: if meth == 'reconstruct_instance': def go(ls_meth): - def reconstruct(instance): + def reconstruct(instance, ctx): ls_meth(self, instance) return reconstruct event.listen(self.class_manager, 'load', diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 2c3a7559db..cda71f6075 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -155,7 +155,7 @@ class CompositeProperty(DescriptorProperty): def _setup_event_handlers(self): """Establish events that populate/expire the composite attribute.""" - def load_handler(state): + def load_handler(state, *args): dict_ = state.dict if self.key in dict_: diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index bb011e5f75..b45f3ba6b3 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -121,7 +121,7 @@ class InstanceEvents(event.Events): """ - def load(self, target): + def load(self, target, context): """Receive an object instance after it has been created via ``__new__``, and after initial attribute population has occurred. @@ -135,29 +135,59 @@ class InstanceEvents(event.Events): attributes and collections may or may not be loaded or even initialized, depending on what's present in the result rows. + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param context: the :class:`.QueryContext` corresponding to the + current :class:`.Query` in progress. + """ - def refresh(self, target): + def refresh(self, target, context, attrs): """Receive an object instance after one or more attributes have - been refreshed. + been refreshed from a query. - This hook is called after expired attributes have been reloaded. + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param context: the :class:`.QueryContext` corresponding to the + current :class:`.Query` in progress. + :param attrs: iterable collection of attribute names which + were populated, or None if all column-mapped, non-deferred + attributes were populated. """ - def expire(self, target, keys): + def expire(self, target, attrs): """Receive an object instance after its attributes or some subset have been expired. 'keys' is a list of attribute names. If None, the entire state was expired. - + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param attrs: iterable collection of attribute + names which were expired, or None if all attributes were + expired. + """ def resurrect(self, target): """Receive an object instance as it is 'resurrected' from garbage collection, which occurs when a "dirty" state falls - out of scope.""" + out of scope. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + + """ class MapperEvents(event.Events): @@ -412,7 +442,10 @@ class MapperEvents(event.Events): :param row: the result row being handled. This may be an actual :class:`.RowProxy` or may be a dictionary containing :class:`.Column` objects as keys. - :param class\_: the mapped class. + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. :return: When configured with ``retval=True``, a return value of ``EXT_STOP`` will bypass instance population by the mapper. A value of ``EXT_CONTINUE`` indicates that diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index a0265f9a81..5ee9d58c38 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -2269,37 +2269,40 @@ class Mapper(object): else: populate_state(state, dict_, row, isnew, only_load_props) - else: + if loaded_instance: + state.manager.dispatch.load(state, context) + elif isnew: + state.manager.dispatch.refresh(state, context, only_load_props) + + elif state in context.partials or state.unloaded: # populate attributes on non-loading instances which have # been expired # TODO: apply eager loads to un-lazy loaded collections ? - if state in context.partials or state.unloaded: - if state in context.partials: - isnew = False - (d_, attrs) = context.partials[state] - else: - isnew = True - attrs = state.unloaded - # allow query.instances to commit the subset of attrs - context.partials[state] = (dict_, attrs) - - if populate_instance: - for fn in populate_instance: - ret = fn(self, context, row, state, - only_load_props=attrs, - instancekey=identitykey, isnew=isnew) - if ret is not EXT_CONTINUE: - break - else: - populate_state(state, dict_, row, isnew, attrs) + if state in context.partials: + isnew = False + (d_, attrs) = context.partials[state] + else: + isnew = True + attrs = state.unloaded + # allow query.instances to commit the subset of attrs + context.partials[state] = (dict_, attrs) + + if populate_instance: + for fn in populate_instance: + ret = fn(self, context, row, state, + only_load_props=attrs, + instancekey=identitykey, isnew=isnew) + if ret is not EXT_CONTINUE: + break else: populate_state(state, dict_, row, isnew, attrs) + else: + populate_state(state, dict_, row, isnew, attrs) + + if isnew: + state.manager.dispatch.refresh(state, context, attrs) - if loaded_instance: - state.manager.dispatch.load(state) - elif isnew: - state.manager.dispatch.refresh(state) if result is not None: if append_result: @@ -2462,7 +2465,7 @@ def validates(*names): return fn return wrap -def _event_on_load(state): +def _event_on_load(state, ctx): instrumenting_mapper = state.manager.info[_INSTRUMENTOR] if instrumenting_mapper._reconstructor: instrumenting_mapper._reconstructor(state.obj()) diff --git a/test/orm/test_events.py b/test/orm/test_events.py index 90a27aa6a8..b2900c93fb 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -12,7 +12,18 @@ from test.lib.testing import eq_ from test.orm import _base, _fixtures from sqlalchemy import event -class MapperEventsTest(_fixtures.FixtureTest): + +class _RemoveListeners(object): + def teardown(self): + # TODO: need to get remove() functionality + # going + Mapper.dispatch._clear() + ClassManager.dispatch._clear() + Session.dispatch._clear() + super(_RemoveListeners, self).teardown() + + +class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): run_inserts = None @testing.resolve_artifact_names @@ -58,12 +69,6 @@ class MapperEventsTest(_fixtures.FixtureTest): b = B() eq_(canary, [('init_a', b), ('init_b', b),('init_e', b)]) - def teardown(self): - # TODO: need to get remove() functionality - # going - Mapper.dispatch._clear() - ClassManager.dispatch._clear() - super(MapperEventsTest, self).teardown() def listen_all(self, mapper, **kw): canary = [] @@ -223,7 +228,171 @@ class MapperEventsTest(_fixtures.FixtureTest): eq_(canary, [User, Address]) -class SessionEventsTest(_fixtures.FixtureTest): +class LoadTest(_fixtures.FixtureTest): + run_inserts = None + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(User, users) + + @testing.resolve_artifact_names + def _fixture(self): + canary = [] + def load(target, ctx): + canary.append("load") + def refresh(target, ctx, attrs): + canary.append(("refresh", attrs)) + + event.listen(User, "load", load) + event.listen(User, "refresh", refresh) + return canary + + @testing.resolve_artifact_names + def test_just_loaded(self): + canary = self._fixture() + + sess = Session() + + u1 = User(name='u1') + sess.add(u1) + sess.commit() + sess.close() + + sess.query(User).first() + eq_(canary, ['load']) + + @testing.resolve_artifact_names + def test_repeated_rows(self): + canary = self._fixture() + + sess = Session() + + u1 = User(name='u1') + sess.add(u1) + sess.commit() + sess.close() + + sess.query(User).union_all(sess.query(User)).all() + eq_(canary, ['load']) + + + +class RefreshTest(_fixtures.FixtureTest): + run_inserts = None + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(User, users) + + @testing.resolve_artifact_names + def _fixture(self): + canary = [] + def load(target, ctx): + canary.append("load") + def refresh(target, ctx, attrs): + canary.append(("refresh", attrs)) + + event.listen(User, "load", load) + event.listen(User, "refresh", refresh) + return canary + + @testing.resolve_artifact_names + def test_already_present(self): + canary = self._fixture() + + sess = Session() + + u1 = User(name='u1') + sess.add(u1) + sess.flush() + + sess.query(User).first() + eq_(canary, []) + + @testing.resolve_artifact_names + def test_repeated_rows(self): + canary = self._fixture() + + sess = Session() + + u1 = User(name='u1') + sess.add(u1) + sess.commit() + + sess.query(User).union_all(sess.query(User)).all() + eq_(canary, [('refresh', set(['id','name']))]) + + @testing.resolve_artifact_names + def test_via_refresh_state(self): + canary = self._fixture() + + sess = Session() + + u1 = User(name='u1') + sess.add(u1) + sess.commit() + + u1.name + eq_(canary, [('refresh', set(['id','name']))]) + + @testing.resolve_artifact_names + def test_was_expired(self): + canary = self._fixture() + + sess = Session() + + u1 = User(name='u1') + sess.add(u1) + sess.flush() + sess.expire(u1) + + sess.query(User).first() + eq_(canary, [('refresh', set(['id','name']))]) + + @testing.resolve_artifact_names + def test_was_expired_via_commit(self): + canary = self._fixture() + + sess = Session() + + u1 = User(name='u1') + sess.add(u1) + sess.commit() + + sess.query(User).first() + eq_(canary, [('refresh', set(['id','name']))]) + + @testing.resolve_artifact_names + def test_was_expired_attrs(self): + canary = self._fixture() + + sess = Session() + + u1 = User(name='u1') + sess.add(u1) + sess.flush() + sess.expire(u1, ['name']) + + sess.query(User).first() + eq_(canary, [('refresh', set(['name']))]) + + @testing.resolve_artifact_names + def test_populate_existing(self): + canary = self._fixture() + + sess = Session() + + u1 = User(name='u1') + sess.add(u1) + sess.commit() + + sess.query(User).populate_existing().first() + eq_(canary, [('refresh', None)]) + + +class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): run_inserts = None def test_class_listen(self): @@ -491,12 +660,6 @@ class SessionEventsTest(_fixtures.FixtureTest): ] ) - def teardown(self): - # TODO: need to get remove() functionality - # going - Session.dispatch._clear() - super(SessionEventsTest, self).teardown() - class MapperExtensionTest(_fixtures.FixtureTest): diff --git a/test/orm/test_merge.py b/test/orm/test_merge.py index d190db96fc..1eafa59e88 100644 --- a/test/orm/test_merge.py +++ b/test/orm/test_merge.py @@ -20,7 +20,7 @@ class MergeTest(_fixtures.FixtureTest): def load_tracker(self, cls, canary=None): if canary is None: - def canary(instance): + def canary(instance, *args): canary.called += 1 canary.called = 0