]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
further refinements to the previous session.expunge() fix
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 10 Jul 2007 21:09:26 +0000 (21:09 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 10 Jul 2007 21:09:26 +0000 (21:09 +0000)
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/unitofwork.py
test/orm/session.py

index ddf7d6251c8bb1af496178a65f1c1d6b4fae6858..15e422eec19dc20573c947a051db74bb700c01db 100644 (file)
@@ -380,10 +380,11 @@ class Session(object):
         Cascading will be applied according to the *expunge* cascade
         rule.
         """
-
+        self._validate_persistent(object)
         for c in [object] + list(_object_mapper(object).cascade_iterator('expunge', object)):
-            self.uow._remove_deleted(c)
-            self._unattach(c)
+            if c in self:
+                self.uow._remove_deleted(c)
+                self._unattach(c)
 
     def save(self, object, entity_name=None):
         """Add a transient (unsaved) instance to this ``Session``.
@@ -615,16 +616,9 @@ class Session(object):
             obj._sa_session_id = self.hash_key
 
     def _unattach(self, obj):
-        self._validate_attached(obj)
-        del obj._sa_session_id
-
-    def _validate_attached(self, obj):
-        """Validate that the given object is either pending or
-        persistent within this Session.
-        """
-
         if not self._is_attached(obj):
             raise exceptions.InvalidRequestError("Instance '%s' not attached to this Session" % repr(obj))
+        del obj._sa_session_id
 
     def _validate_persistent(self, obj):
         """Validate that the given object is persistent within this
index 5f28d28e76db06acd6aa17a6208f457247c94e3f..c6b0b2689c13cbdac05a715520c237d63a9e96df 100644 (file)
@@ -108,7 +108,7 @@ class UnitOfWork(object):
     echo = logging.echo_property()
 
     def _remove_deleted(self, obj):
-        if hasattr(obj, "_instance_key") and obj._instance_key in self.identity_map:
+        if hasattr(obj, "_instance_key"):
             del self.identity_map[obj._instance_key]
         try:
             self.deleted.remove(obj)
index d659d834bd2ff47318b05117577749ade2902d6e..7e0229a7c8f59c92f2c242f7af7eec1ebcd77493 100644 (file)
@@ -38,6 +38,21 @@ class SessionTest(AssertMixin):
         s.user_name = 'some other user'
         s.flush()
 
+    def test_expunge_cascade(self):
+        tables.data()
+        mapper(Address, addresses)
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref=backref("user", cascade="all"), cascade="all")
+        })
+        session = create_session()
+        u = session.query(User).filter_by(user_id=7).one()
+
+        # get everything to load in both directions
+        print [a.user for a in u.addresses]
+
+        # then see if expunge fails
+        session.expunge(u)
+        
     def test_transaction(self):
         class User(object):pass
         mapper(User, users)