]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
adds some tests, refines out the m2o approach.
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 18 Jun 2010 17:11:19 +0000 (13:11 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 18 Jun 2010 17:11:19 +0000 (13:11 -0400)
lib/sqlalchemy/orm/dependency.py
test/orm/test_cycles.py

index 4d7a038943addce73dcdbc2a4401c91f3ba7f56c..c9a0ce8ed35faa7e01bd6d8915334d464427c92c 100644 (file)
@@ -545,10 +545,11 @@ class ManyToOneDP(DependencyProcessor):
             uow.dependencies.update([
                 (child_saves, after_save),
                 (parent_saves, after_save),
-                (before_delete, parent_deletes),
-                (before_delete, child_deletes),
                 (after_save, child_post_updates),
+
+                (after_save, child_pre_updates),
                 (before_delete, child_pre_updates),
+
                 (child_pre_updates, child_deletes),
             ])
         else:
@@ -567,36 +568,30 @@ class ManyToOneDP(DependencyProcessor):
                                     isdelete, childisdelete):
 
         if self.post_update:
-
+            
             if not isdelete:
-                child_post_updates = unitofwork.PostUpdateThing(
+                parent_post_updates = unitofwork.PostUpdateThing(
                                                     uow, self.parent.primary_base_mapper, False)
                 if childisdelete:
                     uow.dependencies.update([
-                        (save_parent, after_save),
-                        (after_save, child_action), # can remove
-                        
-                        (after_save, child_post_updates),
-                        (child_post_updates, child_action)
+                        (after_save, parent_post_updates),
+                        (parent_post_updates, child_action)
                     ])
                 else:
                     uow.dependencies.update([
                         (save_parent, after_save),
                         (child_action, after_save),
                         
-                        (after_save, child_post_updates)
+                        (after_save, parent_post_updates)
                     ])
             else:
-                child_pre_updates = unitofwork.PostUpdateThing(
+                parent_pre_updates = unitofwork.PostUpdateThing(
                                                     uow, self.parent.primary_base_mapper, True)
 
                 uow.dependencies.update([
-                    (before_delete, delete_parent), # can remove
-                    (before_delete, child_action), # can remove
-                    
-                    (before_delete, child_pre_updates),
-                    (child_pre_updates, delete_parent),
-                    (child_pre_updates, child_action)
+                    (before_delete, parent_pre_updates),
+                    (parent_pre_updates, delete_parent),
+                    (parent_pre_updates, child_action)
                 ])
                     
         elif not isdelete:
index 0327e8a9a1e3eb9587ceffdb4fe5b7d4615b0b08..b41a34aa83753d7a336edc2e344efa27cd4d9d86 100644 (file)
@@ -763,7 +763,10 @@ class OneToManyManyToOneTest(_base.MappedTest):
 
 
 class SelfReferentialPostUpdateTest(_base.MappedTest):
-    """Post_update on a single self-referential mapper"""
+    """Post_update on a single self-referential mapper.
+    
+    
+    """
 
     @classmethod
     def define_tables(cls, metadata):
@@ -785,7 +788,7 @@ class SelfReferentialPostUpdateTest(_base.MappedTest):
                 self.path = path
 
     @testing.resolve_artifact_names
-    def test_basic(self):
+    def test_one(self):
         """Post_update only fires off when needed.
 
         This test case used to produce many superfluous update statements,
@@ -797,7 +800,6 @@ class SelfReferentialPostUpdateTest(_base.MappedTest):
             'children': relationship(
                 Node,
                 primaryjoin=node.c.id==node.c.parent_id,
-                lazy='select',
                 cascade="all",
                 backref=backref("parent", remote_side=node.c.id)
             ),
@@ -805,13 +807,11 @@ class SelfReferentialPostUpdateTest(_base.MappedTest):
                 Node,
                 primaryjoin=node.c.prev_sibling_id==node.c.id,
                 remote_side=node.c.id,
-                lazy='select',
                 uselist=False),
             'next_sibling': relationship(
                 Node,
                 primaryjoin=node.c.next_sibling_id==node.c.id,
                 remote_side=node.c.id,
-                lazy='select',
                 uselist=False,
                 post_update=True)})
 
@@ -849,6 +849,7 @@ class SelfReferentialPostUpdateTest(_base.MappedTest):
         session.flush()
 
         remove_child(root, cats)
+        
         # pre-trigger lazy loader on 'cats' to make the test easier
         cats.children
         self.assert_sql_execution(
@@ -956,3 +957,67 @@ class SelfReferentialPostUpdateTest2(_base.MappedTest):
         assert f2.foo is f1
 
 
+class SelfReferentialPostUpdateTest3(_base.MappedTest):
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('parent', metadata,
+              Column('id', Integer, primary_key=True,
+                     test_needs_autoincrement=True),
+              Column('name', String(50), nullable=False),
+              Column('child_id', Integer,
+                     ForeignKey('child.id', use_alter=True, name='c1'), nullable=True))
+
+        Table('child', metadata,
+           Column('id', Integer, primary_key=True,
+                  test_needs_autoincrement=True),
+           Column('name', String(50), nullable=False),
+           Column('child_id', Integer,
+                  ForeignKey('child.id')),
+           Column('parent_id', Integer,
+                  ForeignKey('parent.id'), nullable=True))
+
+    @classmethod
+    def setup_classes(cls):
+        class Parent(_base.BasicEntity):
+            def __init__(self, name=''):
+                self.name = name
+
+        class Child(_base.BasicEntity):
+            def __init__(self, name=''):
+                self.name = name
+
+    @testing.resolve_artifact_names
+    def test_one(self):
+        mapper(Parent, parent, properties={
+            'children':relationship(Child, primaryjoin=parent.c.id==child.c.parent_id),
+            'child':relationship(Child, primaryjoin=parent.c.child_id==child.c.id, post_update=True)
+        })
+        mapper(Child, child, properties={
+            'parent':relationship(Child, remote_side=child.c.id)
+        })
+        
+        session = create_session()
+        p1 = Parent('p1')
+        c1 = Child('c1')
+        c2 = Child('c2')
+        p1.children =[c1, c2]
+        c2.parent = c1
+        p1.child = c2
+        
+        session.add_all([p1, c1, c2])
+        session.flush()
+
+        p2 = Parent('p2')
+        c3 = Child('c3')
+        p2.children = [c3]
+        p2.child = c3
+        session.add(p2)
+        
+        session.delete(c2)
+        p1.children.remove(c2)
+        p1.child = None
+        session.flush()
+        
+        
+        
+