]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
start adding tests to ensure the size of the uow
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 8 Apr 2010 22:52:04 +0000 (18:52 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 8 Apr 2010 22:52:04 +0000 (18:52 -0400)
lib/sqlalchemy/orm/unitofwork.py
test/orm/test_unitofworkv2.py

index 3d5bf8b9463471dfb73bf47fb4fddf5b60f0c625..bdfaa008b273c722b7535384ffdeeac3b6d8a1c9 100644 (file)
@@ -158,9 +158,8 @@ class UOWTransaction(object):
             for state in self.mappers[mapper]:
                 if self.states[state] == checktup:
                     yield state
-                
-    def execute(self):
-        
+    
+    def _generate_actions(self):
         # execute presort_actions, until all states
         # have been processed.   a presort_action might
         # add new states to the uow.
@@ -171,12 +170,12 @@ class UOWTransaction(object):
                     ret = True
             if not ret:
                 break
-        
+
         # see if the graph of mapper dependencies has cycles.
         self.cycles = cycles = topological.find_cycles(
                                         self.dependencies, 
                                         self.postsort_actions.values())
-        
+
         if cycles:
             # if yes, break the per-mapper actions into
             # per-state actions
@@ -202,20 +201,23 @@ class UOWTransaction(object):
                     for dep in convert[edge[1]]:
                         self.dependencies.add((edge[0], dep))
         
-        postsort_actions = set(
+        return set(
                                 [a for a in self.postsort_actions.values()
                                 if not a.disabled
                                 ]
                             ).difference(cycles)
+
+    def execute(self):
+        postsort_actions = self._generate_actions()
         
-        sort = topological.sort(self.dependencies, postsort_actions)
+        #sort = topological.sort(self.dependencies, postsort_actions)
         #print "--------------"
         #print self.dependencies
-        print postsort_actions
-        print "COUNT OF POSTSORT ACTIONS", len(postsort_actions)
+        #print postsort_actions
+        #print "COUNT OF POSTSORT ACTIONS", len(postsort_actions)
         
         # execute
-        if cycles:
+        if self.cycles:
             for set_ in topological.sort_as_subsets(
                                             self.dependencies, 
                                             postsort_actions):
index 120198ac59b814c55534c79e891a065e618e3691..e9183297d27f8a875aa1602ec95b66bc1d995dcd 100644 (file)
@@ -3,7 +3,8 @@ from sqlalchemy.test import testing
 from sqlalchemy.test.schema import Table, Column
 from sqlalchemy import Integer, String, ForeignKey
 from test.orm import _fixtures, _base
-from sqlalchemy.orm import mapper, relationship, backref, create_session
+from sqlalchemy.orm import mapper, relationship, backref, \
+                            create_session, unitofwork, attributes
 from sqlalchemy.test.assertsql import AllOf, CompiledSQL
 
 from test.orm._fixtures import keywords, addresses, Base, Keyword,  \
@@ -15,6 +16,23 @@ from test.orm._fixtures import keywords, addresses, Base, Keyword,  \
 class UOWTest(_fixtures.FixtureTest, testing.AssertsExecutionResults):
     run_inserts = None
 
+    def _assert_uow_size(self,
+        session, 
+        expected
+    ):
+        uow = unitofwork.UOWTransaction(session)
+        deleted = set(session._deleted)
+        new = set(session._new)
+        dirty = set(session._dirty_states).difference(deleted)
+        for s in new.union(dirty):
+            uow.register_object(s)
+        for d in deleted:
+            uow.register_object(d, isdelete=True)
+        postsort_actions = uow._generate_actions()
+        print postsort_actions
+        eq_(len(postsort_actions), expected, postsort_actions)
+    
+
 class RudimentaryFlushTest(UOWTest):
 
     def test_one_to_many_save(self):
@@ -196,6 +214,43 @@ class RudimentaryFlushTest(UOWTest):
                 ),
         )
 
+    def test_o2m_flush_size(self):
+        mapper(User, users, properties={
+            'addresses':relationship(Address),
+        })
+        mapper(Address, addresses)
+
+        sess = create_session()
+        u1 = User(name='ed')
+        sess.add(u1)
+        self._assert_uow_size(sess, 2)
+
+        sess.flush()
+
+        u1.name='jack'
+
+        self._assert_uow_size(sess, 2)
+        sess.flush()
+
+        a1 = Address(email_address='foo')
+        sess.add(a1)
+        sess.flush()
+
+        u1.addresses.append(a1)
+
+        self._assert_uow_size(sess, 6)
+
+        sess.flush()
+
+        sess = create_session()
+        u1 = sess.query(User).first()
+        u1.name='ed'
+        self._assert_uow_size(sess, 2)
+
+        u1.addresses
+        self._assert_uow_size(sess, 6)
+
+
 class SingleCycleTest(UOWTest):
     def test_one_to_many_save(self):
         mapper(Node, nodes, properties={
@@ -389,6 +444,40 @@ class SingleCycleTest(UOWTest):
  #               sess.flush,
  #       )
 
+    def test_singlecycle_flush_size(self):
+        mapper(Node, nodes, properties={
+            'children':relationship(Node)
+        })
+        sess = create_session()
+        n1 = Node(data='ed')
+        sess.add(n1)
+        self._assert_uow_size(sess, 2)
+
+        sess.flush()
+    
+        n1.data='jack'
+
+        self._assert_uow_size(sess, 2)
+        sess.flush()
+    
+        n2 = Node(data='foo')
+        sess.add(n2)
+        sess.flush()
+    
+        n1.children.append(n2)
+
+        self._assert_uow_size(sess, 4)
+    
+        sess.flush()
+    
+        sess = create_session()
+        n1 = sess.query(Node).first()
+        n1.data='ed'
+        self._assert_uow_size(sess, 2)
+    
+        n1.children
+        self._assert_uow_size(sess, 3)
+
 class SingleCycleM2MTest(_base.MappedTest, testing.AssertsExecutionResults):
 
     @classmethod
@@ -538,6 +627,3 @@ class SingleCycleM2MTest(_base.MappedTest, testing.AssertsExecutionResults):
         )
         
         
-        
-        
-        
\ No newline at end of file