]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- add QueryContext to load(), refresh()
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 31 Dec 2010 16:46:30 +0000 (11:46 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 31 Dec 2010 16:46:30 +0000 (11:46 -0500)
- add list of attribute names to refresh()
- ensure refresh() only called when attributes actually refreshed
- tests.  [ticket:2011]

lib/sqlalchemy/ext/mutable.py
lib/sqlalchemy/orm/deprecated_interfaces.py
lib/sqlalchemy/orm/descriptor_props.py
lib/sqlalchemy/orm/events.py
lib/sqlalchemy/orm/mapper.py
test/orm/test_events.py
test/orm/test_merge.py

index f3bd91efb67f466203d84b7bdd56501e2869b873..e849dfcf38f671635a2ccf97e78fbe95a579c361 100644 (file)
@@ -51,7 +51,7 @@ class Mutable(object):
         key = attribute.key
         parent_cls = attribute.class_
         
-        def load(state):
+        def load(state, *args):
             """Listen for objects loaded or refreshed.   
             
             Wrap the target data member's value with 
@@ -230,7 +230,7 @@ class MutableComposite(object):
         key = attribute.key
         parent_cls = attribute.class_
         
-        def load(state):
+        def load(state, *args):
             """Listen for objects loaded or refreshed.   
             
             Wrap the target data member's value with 
index b294a8d7d6b3732ec91891b5c362fae3de206843..26f164509362bf7e12b66c3b1af9a3a7637561ff 100644 (file)
@@ -83,7 +83,7 @@ class MapperExtension(object):
             if me_meth is not ls_meth:
                 if meth == 'reconstruct_instance':
                     def go(ls_meth):
-                        def reconstruct(instance):
+                        def reconstruct(instance, ctx):
                             ls_meth(self, instance)
                         return reconstruct
                     event.listen(self.class_manager, 'load', 
index 2c3a7559db4440475058cbc606c45c0be903e446..cda71f607556f941656f23ebe63f4132a1904e60 100644 (file)
@@ -155,7 +155,7 @@ class CompositeProperty(DescriptorProperty):
     def _setup_event_handlers(self):
         """Establish events that populate/expire the composite attribute."""
         
-        def load_handler(state):
+        def load_handler(state, *args):
             dict_ = state.dict
             
             if self.key in dict_:
index bb011e5f75de45c315223a5e64cb08bbd19c19b4..b45f3ba6b397c75121f651459dd2aca7b8165ed6 100644 (file)
@@ -121,7 +121,7 @@ class InstanceEvents(event.Events):
 
         """
     
-    def load(self, target):
+    def load(self, target, context):
         """Receive an object instance after it has been created via
         ``__new__``, and after initial attribute population has
         occurred.
@@ -135,29 +135,59 @@ class InstanceEvents(event.Events):
         attributes and collections may or may not be loaded or even 
         initialized, depending on what's present in the result rows.
 
+        :param target: the mapped instance.  If 
+         the event is configured with ``raw=True``, this will 
+         instead be the :class:`.InstanceState` state-management
+         object associated with the instance.
+        :param context: the :class:`.QueryContext` corresponding to the
+         current :class:`.Query` in progress.
+
         """
 
-    def refresh(self, target):
+    def refresh(self, target, context, attrs):
         """Receive an object instance after one or more attributes have 
-        been refreshed.
+        been refreshed from a query.
         
-        This hook is called after expired attributes have been reloaded.
+        :param target: the mapped instance.  If 
+         the event is configured with ``raw=True``, this will 
+         instead be the :class:`.InstanceState` state-management
+         object associated with the instance.
+        :param context: the :class:`.QueryContext` corresponding to the
+         current :class:`.Query` in progress.
+        :param attrs: iterable collection of attribute names which 
+         were populated, or None if all column-mapped, non-deferred
+         attributes were populated.
         
         """
     
-    def expire(self, target, keys):
+    def expire(self, target, attrs):
         """Receive an object instance after its attributes or some subset
         have been expired.
         
         'keys' is a list of attribute names.  If None, the entire
         state was expired.
-        
+
+        :param target: the mapped instance.  If 
+         the event is configured with ``raw=True``, this will 
+         instead be the :class:`.InstanceState` state-management
+         object associated with the instance.
+        :param attrs: iterable collection of attribute
+         names which were expired, or None if all attributes were 
+         expired.
+         
         """
         
     def resurrect(self, target):
         """Receive an object instance as it is 'resurrected' from 
         garbage collection, which occurs when a "dirty" state falls
-        out of scope."""
+        out of scope.
+        
+        :param target: the mapped instance.  If 
+         the event is configured with ``raw=True``, this will 
+         instead be the :class:`.InstanceState` state-management
+         object associated with the instance.
+        
+        """
 
         
 class MapperEvents(event.Events):
@@ -412,7 +442,10 @@ class MapperEvents(event.Events):
         :param row: the result row being handled.  This may be 
          an actual :class:`.RowProxy` or may be a dictionary containing
          :class:`.Column` objects as keys.
-        :param class\_: the mapped class.
+        :param target: the mapped instance.  If 
+         the event is configured with ``raw=True``, this will 
+         instead be the :class:`.InstanceState` state-management
+         object associated with the instance.
         :return: When configured with ``retval=True``, a return
          value of ``EXT_STOP`` will bypass instance population by
          the mapper. A value of ``EXT_CONTINUE`` indicates that
index a0265f9a8162c308ef20d0a274478a84de867f8c..5ee9d58c38f7fdb679684224421930259ecc1fd5 100644 (file)
@@ -2269,37 +2269,40 @@ class Mapper(object):
                 else:
                     populate_state(state, dict_, row, isnew, only_load_props)
                     
-            else:
+                if loaded_instance:
+                    state.manager.dispatch.load(state, context)
+                elif isnew:
+                    state.manager.dispatch.refresh(state, context, only_load_props)
+
+            elif state in context.partials or state.unloaded:
                 # populate attributes on non-loading instances which have 
                 # been expired
                 # TODO: apply eager loads to un-lazy loaded collections ?
-                if state in context.partials or state.unloaded:
 
-                    if state in context.partials:
-                        isnew = False
-                        (d_, attrs) = context.partials[state]
-                    else:
-                        isnew = True
-                        attrs = state.unloaded
-                        # allow query.instances to commit the subset of attrs
-                        context.partials[state] = (dict_, attrs)  
-                    
-                    if populate_instance:
-                        for fn in populate_instance:
-                            ret = fn(self, context, row, state, 
-                                only_load_props=attrs, 
-                                instancekey=identitykey, isnew=isnew)
-                            if ret is not EXT_CONTINUE:
-                                break
-                        else:
-                            populate_state(state, dict_, row, isnew, attrs)
+                if state in context.partials:
+                    isnew = False
+                    (d_, attrs) = context.partials[state]
+                else:
+                    isnew = True
+                    attrs = state.unloaded
+                    # allow query.instances to commit the subset of attrs
+                    context.partials[state] = (dict_, attrs)  
+                
+                if populate_instance:
+                    for fn in populate_instance:
+                        ret = fn(self, context, row, state, 
+                            only_load_props=attrs, 
+                            instancekey=identitykey, isnew=isnew)
+                        if ret is not EXT_CONTINUE:
+                            break
                     else:
                         populate_state(state, dict_, row, isnew, attrs)
+                else:
+                    populate_state(state, dict_, row, isnew, attrs)
+                
+                if isnew:
+                    state.manager.dispatch.refresh(state, context, attrs)
 
-            if loaded_instance:
-                state.manager.dispatch.load(state)
-            elif isnew:
-                state.manager.dispatch.refresh(state)
                 
             if result is not None:
                 if append_result:
@@ -2462,7 +2465,7 @@ def validates(*names):
         return fn
     return wrap
 
-def _event_on_load(state):
+def _event_on_load(state, ctx):
     instrumenting_mapper = state.manager.info[_INSTRUMENTOR]
     if instrumenting_mapper._reconstructor:
         instrumenting_mapper._reconstructor(state.obj())
index 90a27aa6a86b8a2e56f0cf506c26b226bf811039..b2900c93fb6765b2a7fb712497e7035317683767 100644 (file)
@@ -12,7 +12,18 @@ from test.lib.testing import eq_
 from test.orm import _base, _fixtures
 from sqlalchemy import event
 
-class MapperEventsTest(_fixtures.FixtureTest):
+
+class _RemoveListeners(object):
+    def teardown(self):
+        # TODO: need to get remove() functionality
+        # going
+        Mapper.dispatch._clear()
+        ClassManager.dispatch._clear()
+        Session.dispatch._clear()
+        super(_RemoveListeners, self).teardown()
+    
+
+class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest):
     run_inserts = None
     
     @testing.resolve_artifact_names
@@ -58,12 +69,6 @@ class MapperEventsTest(_fixtures.FixtureTest):
         b = B()
         eq_(canary, [('init_a', b), ('init_b', b),('init_e', b)])
     
-    def teardown(self):
-        # TODO: need to get remove() functionality
-        # going
-        Mapper.dispatch._clear()
-        ClassManager.dispatch._clear()
-        super(MapperEventsTest, self).teardown()
         
     def listen_all(self, mapper, **kw):
         canary = []
@@ -223,7 +228,171 @@ class MapperEventsTest(_fixtures.FixtureTest):
         eq_(canary, [User, Address])
     
 
-class SessionEventsTest(_fixtures.FixtureTest):
+class LoadTest(_fixtures.FixtureTest):
+    run_inserts = None
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        mapper(User, users)
+    
+    @testing.resolve_artifact_names
+    def _fixture(self):
+        canary = []
+        def load(target, ctx):
+            canary.append("load")
+        def refresh(target, ctx, attrs):
+            canary.append(("refresh", attrs))
+        
+        event.listen(User, "load", load)
+        event.listen(User, "refresh", refresh)
+        return canary
+
+    @testing.resolve_artifact_names
+    def test_just_loaded(self):
+        canary = self._fixture()
+        
+        sess = Session()
+        
+        u1 = User(name='u1')
+        sess.add(u1)
+        sess.commit()
+        sess.close()
+        
+        sess.query(User).first()
+        eq_(canary, ['load'])
+
+    @testing.resolve_artifact_names
+    def test_repeated_rows(self):
+        canary = self._fixture()
+        
+        sess = Session()
+        
+        u1 = User(name='u1')
+        sess.add(u1)
+        sess.commit()
+        sess.close()
+        
+        sess.query(User).union_all(sess.query(User)).all()
+        eq_(canary, ['load'])
+    
+    
+    
+class RefreshTest(_fixtures.FixtureTest):
+    run_inserts = None
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        mapper(User, users)
+    
+    @testing.resolve_artifact_names
+    def _fixture(self):
+        canary = []
+        def load(target, ctx):
+            canary.append("load")
+        def refresh(target, ctx, attrs):
+            canary.append(("refresh", attrs))
+        
+        event.listen(User, "load", load)
+        event.listen(User, "refresh", refresh)
+        return canary
+
+    @testing.resolve_artifact_names
+    def test_already_present(self):
+        canary = self._fixture()
+        
+        sess = Session()
+        
+        u1 = User(name='u1')
+        sess.add(u1)
+        sess.flush()
+        
+        sess.query(User).first()
+        eq_(canary, [])
+
+    @testing.resolve_artifact_names
+    def test_repeated_rows(self):
+        canary = self._fixture()
+        
+        sess = Session()
+        
+        u1 = User(name='u1')
+        sess.add(u1)
+        sess.commit()
+        
+        sess.query(User).union_all(sess.query(User)).all()
+        eq_(canary, [('refresh', set(['id','name']))])
+
+    @testing.resolve_artifact_names
+    def test_via_refresh_state(self):
+        canary = self._fixture()
+        
+        sess = Session()
+        
+        u1 = User(name='u1')
+        sess.add(u1)
+        sess.commit()
+        
+        u1.name
+        eq_(canary, [('refresh', set(['id','name']))])
+
+    @testing.resolve_artifact_names
+    def test_was_expired(self):
+        canary = self._fixture()
+        
+        sess = Session()
+        
+        u1 = User(name='u1')
+        sess.add(u1)
+        sess.flush()
+        sess.expire(u1)
+        
+        sess.query(User).first()
+        eq_(canary, [('refresh', set(['id','name']))])
+
+    @testing.resolve_artifact_names
+    def test_was_expired_via_commit(self):
+        canary = self._fixture()
+        
+        sess = Session()
+        
+        u1 = User(name='u1')
+        sess.add(u1)
+        sess.commit()
+        
+        sess.query(User).first()
+        eq_(canary, [('refresh', set(['id','name']))])
+
+    @testing.resolve_artifact_names
+    def test_was_expired_attrs(self):
+        canary = self._fixture()
+        
+        sess = Session()
+        
+        u1 = User(name='u1')
+        sess.add(u1)
+        sess.flush()
+        sess.expire(u1, ['name'])
+        
+        sess.query(User).first()
+        eq_(canary, [('refresh', set(['name']))])
+        
+    @testing.resolve_artifact_names
+    def test_populate_existing(self):
+        canary = self._fixture()
+        
+        sess = Session()
+        
+        u1 = User(name='u1')
+        sess.add(u1)
+        sess.commit()
+        
+        sess.query(User).populate_existing().first()
+        eq_(canary, [('refresh', None)])
+    
+
+class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest):
     run_inserts = None
 
     def test_class_listen(self):
@@ -491,12 +660,6 @@ class SessionEventsTest(_fixtures.FixtureTest):
             ]
         )
         
-    def teardown(self):
-        # TODO: need to get remove() functionality
-        # going
-        Session.dispatch._clear()
-        super(SessionEventsTest, self).teardown()
-
 
         
 class MapperExtensionTest(_fixtures.FixtureTest):
index d190db96fc25a6ae3df48f14b67b0f95d6e9231f..1eafa59e880272fec6fb64c89d4ef2db05929fdc 100644 (file)
@@ -20,7 +20,7 @@ class MergeTest(_fixtures.FixtureTest):
 
     def load_tracker(self, cls, canary=None):
         if canary is None:
-            def canary(instance):
+            def canary(instance, *args):
                 canary.called += 1
             canary.called = 0