]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- session checks more carefully when determining "object X already in another session";
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 1 Nov 2007 20:12:36 +0000 (20:12 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 1 Nov 2007 20:12:36 +0000 (20:12 +0000)
e.g. if you pickle a series of objects and unpickle (i.e. as in a Pylons HTTP session
or similar), they can go into a new session without any conflict
- added stricter checks around session.delete() similar to update()
- shored up some old "validate" stuff in session/uow

CHANGES
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/unitofwork.py
test/ext/activemapper.py
test/orm/cascade.py
test/orm/session.py

diff --git a/CHANGES b/CHANGES
index c05ef2188119746087eef2090f24b057c6b8aae3..81adb1e8e10ea948a1c599ecd82d53447f7bca93 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -53,9 +53,6 @@ CHANGES
       single UPDATE; many-to-many relations on the parent object update properly. 
       [ticket:841]
 
-    - it's an error to session.save() an object which is already persistent
-      [ticket:840]
-
     - behavior of query.options() is now fully based on paths, i.e. an option
       such as eagerload_all('x.y.z.y.x') will apply eagerloading to only
       those paths, i.e. and not 'x.y.x'; eagerload('children.children') applies
@@ -70,8 +67,19 @@ CHANGES
     - Added proxying of save_or_update, __contains__ and __iter__ methods for
       scoped sessions.
 
-    - session.update() raises an error when updating an instance that is already
-      in the session with a different identity.
+  - session API has been solidified:
+  
+    - it's an error to session.save() an object which is already persistent
+      [ticket:840]
+
+    - it's an error to session.delete() an object which is *not* persistent
+
+    - session.update() and session.delete() raise an error when updating/deleting
+      an instance that is already in the session with a different identity.
+      
+    - session checks more carefully when determining "object X already in another session";
+      e.g. if you pickle a series of objects and unpickle (i.e. as in a Pylons HTTP session
+      or similar), they can go into a new session without any conflict
 
 0.4.0
 -----
index e39062b2cdb2f98ebf4daeb178b0573a5dd7b242..6be72ecefee0d30abb0fb984df14a66ba32b3be4 100644 (file)
@@ -826,21 +826,16 @@ class Session(object):
                                                 lambda c, e:self._save_or_update_impl(c, e),
                                                 halt_on=lambda c:c in self)
 
-    def _save_or_update_impl(self, object, entity_name=None):
-        key = getattr(object, '_instance_key', None)
-        if key is None:
-            self._save_impl(object, entity_name=entity_name)
-        else:
-            self._update_impl(object, entity_name=entity_name)
-
     def delete(self, object):
         """Mark the given instance as deleted.
 
         The delete operation occurs upon ``flush()``.
         """
 
-        for c in [object] + list(_object_mapper(object).cascade_iterator('delete', object)):
-            self.uow.register_deleted(c)
+        self._delete_impl(object)
+        for c in list(_object_mapper(object).cascade_iterator('delete', object)):
+            self._delete_impl(c, ignore_transient=True)
+
 
     def merge(self, object, entity_name=None, _recursive=None):
         """Copy the state of the given `object` onto the persistent
@@ -966,7 +961,7 @@ class Session(object):
             self.uow.register_new(obj)
 
     def _update_impl(self, obj, **kwargs):
-        if self._is_attached(obj) and obj not in self.deleted:
+        if obj in self and obj not in self.deleted:
             return
         if not hasattr(obj, '_instance_key'):
             raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(obj))
@@ -974,6 +969,26 @@ class Session(object):
             raise exceptions.InvalidRequestError("Instance '%s' is with key %s already persisted with a different identity" % (mapperutil.instance_str(obj), obj._instance_key))
         self._attach(obj)
 
+    def _save_or_update_impl(self, object, entity_name=None):
+        key = getattr(object, '_instance_key', None)
+        if key is None:
+            self._save_impl(object, entity_name=entity_name)
+        else:
+            self._update_impl(object, entity_name=entity_name)
+
+    def _delete_impl(self, obj, ignore_transient=False):
+        if obj in self and obj in self.deleted:
+            return
+        if not hasattr(obj, '_instance_key'):
+            if ignore_transient:
+                return
+            else:
+                raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(obj))
+        if self.identity_map.get(obj._instance_key, obj) is not obj:
+            raise exceptions.InvalidRequestError("Instance '%s' is with key %s already persisted with a different identity" % (mapperutil.instance_str(obj), obj._instance_key))
+        self._attach(obj)
+        self.uow.register_deleted(obj)
+
     def _register_persistent(self, obj):
         obj._sa_session_id = self.hash_key
         self.identity_map[obj._instance_key] = obj
@@ -982,39 +997,26 @@ class Session(object):
     def _attach(self, obj):
         old_id = getattr(obj, '_sa_session_id', None)
         if old_id != self.hash_key:
-            if old_id is not None and old_id in _sessions:
+            if old_id is not None and old_id in _sessions and obj in _sessions[old_id]:
                 raise exceptions.InvalidRequestError("Object '%s' is already attached "
                                                      "to session '%s' (this is '%s')" %
                                                      (mapperutil.instance_str(obj), old_id, id(self)))
 
-                # auto-removal from the old session is disabled.  but if we decide to
-                # turn it back on, do it as below: gingerly since _sessions is a WeakValueDict
-                # and it might be affected by other threads
-                #try:
-                #    sess = _sessions[old]
-                #except KeyError:
-                #    sess = None
-                #if sess is not None:
-                #    sess.expunge(old)
             key = getattr(obj, '_instance_key', None)
             if key is not None:
                 self.identity_map[key] = obj
             obj._sa_session_id = self.hash_key
-
+        
     def _unattach(self, obj):
-        if not self._is_attached(obj):
-            raise exceptions.InvalidRequestError("Instance '%s' not attached to this Session" % mapperutil.instance_str(obj))
-        del obj._sa_session_id
+        if obj._sa_session_id == self.hash_key:
+            del obj._sa_session_id
 
     def _validate_persistent(self, obj):
         """Validate that the given object is persistent within this
         ``Session``.
         """
-
-        self.uow._validate_obj(obj)
-
-    def _is_attached(self, obj):
-        return getattr(obj, '_sa_session_id', None) == self.hash_key
+        
+        return obj in self
 
     def __contains__(self, obj):
         """return True if the given object is associated with this session.
@@ -1023,7 +1025,7 @@ class Session(object):
         result of True.
         """
         
-        return self._is_attached(obj) and (obj in self.uow.new or obj._instance_key in self.identity_map)
+        return obj in self.uow.new or (hasattr(obj, '_instance_key') and self.identity_map.get(obj._instance_key) is obj)
 
     def __iter__(self):
         """return an iterator of all objects which are pending or persistent within this Session."""
@@ -1090,7 +1092,9 @@ def object_session(obj):
 
     hashkey = getattr(obj, '_sa_session_id', None)
     if hashkey is not None:
-        return _sessions.get(hashkey)
+        sess = _sessions.get(hashkey)
+        if obj in sess:
+            return sess
     return None
 
 # Lazy initialization to avoid circular imports
index 43f0d46d915fa3ca9615914995273ef3c0316c41..0ce354d6f3b673a608f0b9ea849efcbba8f573f6 100644 (file)
@@ -109,11 +109,6 @@ class UnitOfWork(object):
         except KeyError:
             pass
 
-    def _validate_obj(self, obj):
-        if (hasattr(obj, '_instance_key') and obj._instance_key not in self.identity_map) or \
-            (not hasattr(obj, '_instance_key') and obj not in self.new):
-            raise exceptions.InvalidRequestError("Instance '%s' is not attached or pending within this session" % repr(obj))
-
     def _is_valid(self, obj):
         if (hasattr(obj, '_instance_key') and obj._instance_key not in self.identity_map) or \
             (not hasattr(obj, '_instance_key') and obj not in self.new):
@@ -147,9 +142,7 @@ class UnitOfWork(object):
     def register_deleted(self, obj):
         """register the given persistent object as 'to be deleted' within this unit of work."""
         
-        if obj not in self.deleted:
-            self._validate_obj(obj)
-            self.deleted.add(obj)
+        self.deleted.add(obj)
 
     def locate_dirty(self):
         """return a set of all persistent instances within this unit of work which 
index ca466170e38b1c50ea0aad15bc58bced245c7fb6..860f52813e0b71806cd9280c245b117b0c1ddf3d 100644 (file)
@@ -237,6 +237,7 @@ class testcase(PersistTest):
         # uses a function which I dont think existed when you first wrote ActiveMapper.
         p1 = self.create_person_one()
         self.assertEquals(p1.preferences.person, p1)
+        objectstore.flush()
         objectstore.delete(p1)
         
         objectstore.flush()
index e24fbbdbab95655b7b3b168101ddd5fc9c308ca4..adbb95265113be933735361c19792ceb1fd0329f 100644 (file)
@@ -384,6 +384,7 @@ class UnsavedOrphansTest(ORMTest):
 
         u = User()
         s.save(u)
+        s.flush()
         a = Address()
         assert a not in s.new
         u.addresses.append(a)
@@ -394,7 +395,6 @@ class UnsavedOrphansTest(ORMTest):
             assert False
         except exceptions.FlushError:
             assert True
-        assert u.user_id is None, "Error: user should not be persistent"
         assert a.address_id is None, "Error: address should not be persistent"
 
 
index 530380b8108249e71a125a4193b9b266ef8e73cc..0639a7d9501a2f612876dcae0c9a64690da4917a 100644 (file)
@@ -7,6 +7,7 @@ from sqlalchemy.orm.session import Session as SessionCls
 from testlib import *
 from testlib.tables import *
 from testlib import fixtures, tables
+import pickle
 
 class SessionTest(AssertMixin):
     def setUpAll(self):
@@ -396,7 +397,7 @@ class SessionTest(AssertMixin):
         
         
     @engines.close_open_connections
-    def test_save_update(self):
+    def test_save_update_delete(self):
         
         s = create_session()
         class User(object):
@@ -410,10 +411,16 @@ class SessionTest(AssertMixin):
             assert False
         except exceptions.InvalidRequestError, e:
             assert str(e) == "Instance 'User@%s' is not persisted" % hex(id(user))
+
+        try:
+            s.delete(user)
+            assert False
+        except exceptions.InvalidRequestError, e:
+            assert str(e) == "Instance 'User@%s' is not persisted" % hex(id(user))
             
         s.save(user)
         s.flush()
-        user = s.query(User).selectone()
+        user = s.query(User).one()
         s.expunge(user)
         assert user not in s
         
@@ -424,7 +431,8 @@ class SessionTest(AssertMixin):
         assert user in s.dirty
         s.flush()
         s.clear()
-        user = s.query(User).selectone()
+        assert s.query(User).count() == 1
+        user = s.query(User).one()
         assert user.user_name == 'fred'
         
         # ensure its not dirty if no changes occur
@@ -440,6 +448,25 @@ class SessionTest(AssertMixin):
         except exceptions.InvalidRequestError, e:
             assert str(e) == "Instance 'User@%s' is already persistent" % hex(id(user))
     
+        s2 = create_session()
+        try:
+            s2.delete(user)
+            assert False
+        except exceptions.InvalidRequestError, e:
+            assert "is already attached to session" in str(e)
+            
+        u2 = s2.query(User).get(user.user_id)
+        try:
+            s.delete(u2)
+            assert False
+        except exceptions.InvalidRequestError, e:
+            assert "already persisted with a different identity" in str(e)
+    
+        s.delete(user)
+        s.flush()
+        assert user not in s
+        assert s.query(User).count() == 0
+        
     def test_is_modified(self):
         s = create_session()
         class User(object):pass
@@ -514,7 +541,7 @@ class SessionTest(AssertMixin):
         # save user
         s.save(User())
         s.flush()
-        user = s.query(User).selectone()
+        user = s.query(User).one()
         user = None
         print s.identity_map
         import gc
@@ -585,8 +612,8 @@ class SessionTest(AssertMixin):
         assert a not in s
         s.flush()
         s.clear()
-        assert s.query(User).selectone().user_id == u.user_id
-        assert s.query(Address).selectfirst() is None
+        assert s.query(User).one().user_id == u.user_id
+        assert s.query(Address).first() is None
         
         clear_mappers()
         
@@ -605,8 +632,8 @@ class SessionTest(AssertMixin):
         assert a in s
         s.flush()
         s.clear()
-        assert s.query(Address).selectone().address_id == a.address_id
-        assert s.query(User).selectfirst() is None
+        assert s.query(Address).one().address_id == a.address_id
+        assert s.query(User).first() is None
 
     def _assert_key(self, got, expect):
         assert got == expect, "expected %r got %r" % (expect, got)
@@ -681,7 +708,25 @@ class SessionTest(AssertMixin):
         log = []
         sess.commit()
         assert log == ['before_commit', 'after_commit']
-
+    
+    def test_pickled_update(self):
+        mapper(User, users)
+        sess1 = create_session()
+        sess2 = create_session()
+        
+        u1 = User()
+        sess1.save(u1)
+        
+        try:
+            sess2.save(u1)
+            assert False
+        except exceptions.InvalidRequestError, e:
+            assert "already attached to session" in str(e)
+            
+        u2 = pickle.loads(pickle.dumps(u1))
+        
+        sess2.save(u2)
+        
     def test_duplicate_update(self):
         mapper(User, users)
         Session = sessionmaker()