]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fix expunging of orphans with more than one parent
authorAnts Aasma <ants.aasma@gmail.com>
Mon, 10 Mar 2008 20:49:27 +0000 (20:49 +0000)
committerAnts Aasma <ants.aasma@gmail.com>
Mon, 10 Mar 2008 20:49:27 +0000 (20:49 +0000)
- move flush error for orphans from Mapper to UnitOfWork

lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/unitofwork.py
test/orm/cascade.py

index 4efa3f1e8a693fa07603d87c49d33f8d1f05784d..ae48f9e55f8d7a8cf8c5433d159fd2b2933f0b97 100644 (file)
@@ -174,22 +174,10 @@ class Mapper(object):
             self.logger.debug("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (not self.non_primary and "|non-primary" or "") + ") " + msg)
 
     def _is_orphan(self, obj):
-        optimistic = has_identity(obj)
         for (key,klass) in self.delete_orphans:
-            if attributes.has_parent(klass, obj, key, optimistic=optimistic):
-               return False
-        else:
-            if self.delete_orphans:
-                if not has_identity(obj):
-                    raise exceptions.FlushError("instance %s is an unsaved, pending instance and is an orphan (is not attached to %s)" %
-                    (
-                        obj,
-                        ", nor ".join(["any parent '%s' instance via that classes' '%s' attribute" % (klass.__name__, key) for (key,klass) in self.delete_orphans])
-                    ))
-                else:
-                    return True
-            else:
+            if attributes.has_parent(klass, obj, key, optimistic=has_identity(obj)):
                 return False
+        return bool(self.delete_orphans)
 
     def get_property(self, key, resolve_synonyms=False, raiseerr=True):
         """return a MapperProperty associated with the given key."""
index 48b6ea4cb33a0a72b1460061dbe1933441a941a0..ae55b4c94324a8328406eb497fcd30b07d8c779a 100644 (file)
@@ -23,7 +23,7 @@ import StringIO, weakref
 from sqlalchemy import util, logging, topological, exceptions
 from sqlalchemy.orm import attributes, interfaces
 from sqlalchemy.orm import util as mapperutil
-from sqlalchemy.orm.mapper import object_mapper, _state_mapper
+from sqlalchemy.orm.mapper import object_mapper, _state_mapper, has_identity
 
 # Load lazily
 object_session = None
@@ -37,23 +37,25 @@ class UOWEventHandler(interfaces.AttributeExtension):
         self.key = key
         self.class_ = class_
         self.cascade = cascade
+    
+    def _target_mapper(self, obj):
+        prop = object_mapper(obj).get_property(self.key)
+        return prop.mapper
 
     def append(self, obj, item, initiator):
         # process "save_update" cascade rules for when an instance is appended to the list of another instance
         sess = object_session(obj)
         if sess:
             if self.cascade.save_update and item not in sess:
-                mapper = object_mapper(obj)
-                prop = mapper.get_property(self.key)
-                ename = prop.mapper.entity_name
-                sess.save_or_update(item, entity_name=ename)
+                sess.save_or_update(item, entity_name=self._target_mapper(obj).entity_name)
 
     def remove(self, obj, item, initiator):
         sess = object_session(obj)
         if sess:
             # expunge pending orphans
             if self.cascade.delete_orphan and item in sess.new:
-                sess.expunge(item)
+                if self._target_mapper(obj)._is_orphan(item):
+                    sess.expunge(item)
 
     def set(self, obj, newvalue, oldvalue, initiator):
         # process "save_update" cascade rules for when an instance is attached to another instance
@@ -62,10 +64,7 @@ class UOWEventHandler(interfaces.AttributeExtension):
         sess = object_session(obj)
         if sess:
             if newvalue is not None and self.cascade.save_update and newvalue not in sess:
-                mapper = object_mapper(obj)
-                prop = mapper.get_property(self.key)
-                ename = prop.mapper.entity_name
-                sess.save_or_update(newvalue, entity_name=ename)
+                sess.save_or_update(newvalue, entity_name=self._target_mapper(obj).entity_name)
             if self.cascade.delete_orphan and oldvalue in sess.new:
                 sess.expunge(oldvalue)
 
@@ -210,7 +209,15 @@ class UnitOfWork(object):
             if state in processed:
                 continue
 
-            flush_context.register_object(state, isdelete=_state_mapper(state)._is_orphan(state.obj()))
+            obj = state.obj()
+            is_orphan = _state_mapper(state)._is_orphan(obj)
+            if is_orphan and not has_identity(obj):
+                raise exceptions.FlushError("instance %s is an unsaved, pending instance and is an orphan (is not attached to %s)" %
+                    (
+                        obj,
+                        ", nor ".join(["any parent '%s' instance via that classes' '%s' attribute" % (klass.__name__, key) for (key,klass) in _state_mapper(state).delete_orphans])
+                    ))
+            flush_context.register_object(state, isdelete=is_orphan)
             processed.add(state)
 
         # put all remaining deletes into the flush context.
index e2b2c61227695789865ab0f106cfdbda956257c0..1abf01de77eeb57d9aa8b67ca86ccc4dd80a96ed 100644 (file)
@@ -415,6 +415,61 @@ class UnsavedOrphansTest2(ORMTest):
         assert items.count().scalar() == 0
         assert attributes.count().scalar() == 0
 
+class UnsavedOrphansTest3(ORMTest):
+    """test not expuning double parents"""
+
+    def define_tables(self, meta):
+        global sales_reps, accounts, customers
+        sales_reps = Table('sales_reps', meta,
+            Column('sales_rep_id', Integer, Sequence('sales_rep_id_seq'), primary_key = True),
+            Column('name', String(50)),
+        )
+        accounts = Table('accounts', meta,
+            Column('account_id', Integer, Sequence('account_id_seq'), primary_key = True),
+            Column('balance', Integer),
+        )
+        customers = Table('customers', meta,
+            Column('customer_id', Integer, Sequence('customer_id_seq'), primary_key = True),
+            Column('name', String(50)),
+            Column('sales_rep_id', Integer, ForeignKey('sales_reps.sales_rep_id')),
+            Column('account_id', Integer, ForeignKey('accounts.account_id')),
+        )
+
+    def test_double_parent_expunge(self):
+        """test that removing a pending item from a collection expunges it from the session."""
+        class Customer(fixtures.Base):
+            pass
+        class Account(fixtures.Base):
+            pass
+        class SalesRep(fixtures.Base):
+            pass
+
+        mapper(Customer, customers)
+        mapper(Account, accounts, properties=dict(
+            customers=relation(Customer, cascade="all,delete-orphan", backref="account")
+        ))
+        mapper(SalesRep, sales_reps, properties=dict(
+            customers=relation(Customer, cascade="all,delete-orphan", backref="sales_rep")
+        ))
+        s = create_session()
+
+        a = Account(balance=0)
+        sr = SalesRep(name="John")
+        [s.save(x) for x in [a,sr]]
+        s.flush()
+        
+        c = Customer(name="Jane")
+
+        a.customers.append(c)
+        sr.customers.append(c)
+        assert c in s
+        
+        a.customers.remove(c)
+        assert c in s, "Should not expunge customer yet, still has one parent"
+
+        sr.customers.remove(c)
+        assert c not in s, "Should expunge customer when both parents are gone"
+
 class DoubleParentOrphanTest(ORMTest):
     """test orphan detection for an entity with two parent relations"""