]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- refactor expire_attributes into two simpler methods
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 21 Dec 2010 16:08:12 +0000 (11:08 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 21 Dec 2010 16:08:12 +0000 (11:08 -0500)
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/state.py
test/orm/test_attributes.py
test/orm/test_extendedattr.py

index 8c1db3f1ae9b9c8ad90f46c6ee39875611ee87bc..ca23edf3c5c730e91db1809224a44835e0d8ce37 100644 (file)
@@ -1891,7 +1891,7 @@ class Mapper(object):
                     [p.key for p in mapper._readonly_props]
                 )
                 if readonly:
-                    sessionlib._expire_state(state, state.dict, readonly)
+                    state.expire_attributes(state.dict, readonly)
 
             # if eager_defaults option is enabled,
             # refresh whatever has been expired.
@@ -1921,7 +1921,7 @@ class Mapper(object):
                 self._set_state_attr_by_column(state, dict_, c, params[c.key])
 
         if postfetch_cols:
-            sessionlib._expire_state(state, state.dict, 
+            state.expire_attributes(state.dict, 
                                 [self._columntoproperty[c].key 
                                 for c in postfetch_cols]
                             )
index 88a8a8ea661e8d6ab85ad4f364420ce33840e2d3..325a94643a777d52bb6cddbcce717fc68cc0c4bf 100644 (file)
@@ -307,16 +307,14 @@ class SessionTransaction(object):
         assert not self.session._deleted
 
         for s in self.session.identity_map.all_states():
-            _expire_state(s, s.dict, None,
-                          instance_dict=self.session.identity_map)
+            s.expire(s.dict, self.session.identity_map._modified)
 
     def _remove_snapshot(self):
         assert self._is_transaction_boundary
 
         if not self.nested and self.session.expire_on_commit:
             for s in self.session.identity_map.all_states():
-                _expire_state(s, s.dict, None,
-                              instance_dict=self.session.identity_map)
+                s.expire(s.dict, self.session.identity_map._modified)
 
     def _connection_for_bind(self, bind):
         self._assert_is_active()
@@ -937,7 +935,7 @@ class Session(object):
 
         """
         for state in self.identity_map.all_states():
-            _expire_state(state, state.dict, None, instance_dict=self.identity_map)
+            state.expire(state.dict, self.identity_map._modified)
 
     def expire(self, instance, attribute_names=None):
         """Expire the attributes on an instance.
@@ -975,9 +973,7 @@ class Session(object):
     def _expire_state(self, state, attribute_names):
         self._validate_persistent(state)
         if attribute_names:
-            _expire_state(state, state.dict, 
-                                attribute_names=attribute_names, 
-                                instance_dict=self.identity_map)
+            state.expire_attributes(state.dict, attribute_names)
         else:
             # pre-fetch the full cascade since the expire is going to
             # remove associations
@@ -991,8 +987,7 @@ class Session(object):
         """Expire a state if persistent, else expunge if pending"""
         
         if state.key:
-            _expire_state(state, state.dict, None, 
-                                instance_dict=self.identity_map)
+            state.expire(state.dict, self.identity_map._modified)
         elif state in self._new:
             self._new.pop(state)
             state.detach()
@@ -1603,8 +1598,6 @@ class Session(object):
 
         return util.IdentitySet(self._new.values())
 
-_expire_state = state.InstanceState.expire_attributes
-    
 _sessions = weakref.WeakValueDictionary()
 
 def make_transient(instance):
index fd55cfcb87bce90c81aa7326458623edbadbd55c..909977cc47c6dc704ac248f4bc0884b8083814dc 100644 (file)
@@ -224,43 +224,35 @@ class InstanceState(object):
         dict_.pop(key, None)
         self.callables[key] = callable_
     
-    def expire_attributes(self, dict_, attribute_names, instance_dict=None):
-        """Expire all or a group of attributes.
-        
-        If all attributes are expired, the "expired" flag is set to True.
-        
-        """
-        # we would like to assert that 'self.key is not None' here, 
-        # but there are many cases where the mapper will expire
-        # a newly persisted instance within the flush, before the
-        # key is assigned, and even cases where the attribute refresh
-        # occurs fully, within the flush(), before this key is assigned.
-        # the key is assigned late within the flush() to assist in
-        # "key switch" bookkeeping scenarios.
-        
-        if attribute_names is None:
-            attribute_names = self.manager.keys()
-            self.expired = True
-            if self.modified:
-                if not instance_dict:
-                    instance_dict = self._instance_dict()
-                    if instance_dict:
-                        instance_dict._modified.discard(self)
-                else:
-                    instance_dict._modified.discard(self)
+    def expire(self, dict_, modified_set):
+        self.expired = True
+        if self.modified:
+            modified_set.discard(self)
 
-            self.modified = False
-            filter_deferred = True
-        else:
-            filter_deferred = False
+        self.modified = False
 
         pending = self.__dict__.get('pending', None)
         mutable_dict = self.mutable_dict
+        self.committed_state.clear()
+        if mutable_dict:
+            mutable_dict.clear()
+        if pending:
+            pending.clear()
         
-        for key in attribute_names:
+        for key in self.manager:
             impl = self.manager[key].impl
             if impl.accepts_scalar_loader and \
-                (not filter_deferred or impl.expire_missing or key in dict_):
+                (impl.expire_missing or key in dict_):
+                self.callables[key] = self
+            dict_.pop(key, None)
+        
+    def expire_attributes(self, dict_, attribute_names):
+        pending = self.__dict__.get('pending', None)
+        mutable_dict = self.mutable_dict
+        
+        for key in attribute_names:
+            impl = self.manager[key].impl
+            if impl.accepts_scalar_loader:
                 self.callables[key] = self
             dict_.pop(key, None)
             
index 91f61c05c05cee6ccf68c743dddb1bb893c4d202..9bbfa98eb95eeb438c5454f3051c194c1ba0ecb6 100644 (file)
@@ -155,21 +155,21 @@ class AttributesTest(_base.ORMTest):
         attributes.register_attribute(Foo, 'b', uselist=False, useobject=False)
 
         f = Foo()
-        attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
+        attributes.instance_state(f).expire(attributes.instance_dict(f), set())
         eq_(f.a, "this is a")
         eq_(f.b, 12)
 
         f.a = "this is some new a"
-        attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
+        attributes.instance_state(f).expire(attributes.instance_dict(f), set())
         eq_(f.a, "this is a")
         eq_(f.b, 12)
 
-        attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
+        attributes.instance_state(f).expire(attributes.instance_dict(f), set())
         f.a = "this is another new a"
         eq_(f.a, "this is another new a")
         eq_(f.b, 12)
 
-        attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
+        attributes.instance_state(f).expire(attributes.instance_dict(f), set())
         eq_(f.a, "this is a")
         eq_(f.b, 12)
 
@@ -177,7 +177,7 @@ class AttributesTest(_base.ORMTest):
         eq_(f.a, None)
         eq_(f.b, 12)
 
-        attributes.instance_state(f).commit_all(attributes.instance_dict(f))
+        attributes.instance_state(f).commit_all(attributes.instance_dict(f), set())
         eq_(f.a, None)
         eq_(f.b, 12)
 
@@ -195,7 +195,7 @@ class AttributesTest(_base.ORMTest):
         attributes.register_attribute(MyTest, 'b', uselist=False, useobject=False)
 
         m = MyTest()
-        attributes.instance_state(m).expire_attributes(attributes.instance_dict(m), None)
+        attributes.instance_state(m).expire(attributes.instance_dict(m), set())
         assert 'a' not in m.__dict__
         m2 = pickle.loads(pickle.dumps(m))
         assert 'a' not in m2.__dict__
index 2eca1ac387dfedfa5de5843d09a5b146ae664b3e..c20cad0da71501ea727aac677c9db3cc1c833519 100644 (file)
@@ -161,21 +161,21 @@ class UserDefinedExtensionTest(_base.ORMTest):
             
             assert Foo in instrumentation.instrumentation_registry._state_finders
             f = Foo()
-            attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
+            attributes.instance_state(f).expire(attributes.instance_dict(f), set())
             eq_(f.a, "this is a")
             eq_(f.b, 12)
 
             f.a = "this is some new a"
-            attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
+            attributes.instance_state(f).expire(attributes.instance_dict(f), set())
             eq_(f.a, "this is a")
             eq_(f.b, 12)
 
-            attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
+            attributes.instance_state(f).expire(attributes.instance_dict(f), set())
             f.a = "this is another new a"
             eq_(f.a, "this is another new a")
             eq_(f.b, 12)
 
-            attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
+            attributes.instance_state(f).expire(attributes.instance_dict(f), set())
             eq_(f.a, "this is a")
             eq_(f.b, 12)