From: Mike Bayer Date: Thu, 8 Apr 2010 22:52:04 +0000 (-0400) Subject: start adding tests to ensure the size of the uow X-Git-Tag: rel_0_6_0~35 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=72fb740c22a1c59abd2ca62246d2cc3d5071429e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git start adding tests to ensure the size of the uow --- diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 3d5bf8b946..bdfaa008b2 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -158,9 +158,8 @@ class UOWTransaction(object): for state in self.mappers[mapper]: if self.states[state] == checktup: yield state - - def execute(self): - + + def _generate_actions(self): # execute presort_actions, until all states # have been processed. a presort_action might # add new states to the uow. @@ -171,12 +170,12 @@ class UOWTransaction(object): ret = True if not ret: break - + # see if the graph of mapper dependencies has cycles. self.cycles = cycles = topological.find_cycles( self.dependencies, self.postsort_actions.values()) - + if cycles: # if yes, break the per-mapper actions into # per-state actions @@ -202,20 +201,23 @@ class UOWTransaction(object): for dep in convert[edge[1]]: self.dependencies.add((edge[0], dep)) - postsort_actions = set( + return set( [a for a in self.postsort_actions.values() if not a.disabled ] ).difference(cycles) + + def execute(self): + postsort_actions = self._generate_actions() - sort = topological.sort(self.dependencies, postsort_actions) + #sort = topological.sort(self.dependencies, postsort_actions) #print "--------------" #print self.dependencies - print postsort_actions - print "COUNT OF POSTSORT ACTIONS", len(postsort_actions) + #print postsort_actions + #print "COUNT OF POSTSORT ACTIONS", len(postsort_actions) # execute - if cycles: + if self.cycles: for set_ in topological.sort_as_subsets( self.dependencies, postsort_actions): diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index 120198ac59..e9183297d2 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -3,7 +3,8 @@ from sqlalchemy.test import testing from sqlalchemy.test.schema import Table, Column from sqlalchemy import Integer, String, ForeignKey from test.orm import _fixtures, _base -from sqlalchemy.orm import mapper, relationship, backref, create_session +from sqlalchemy.orm import mapper, relationship, backref, \ + create_session, unitofwork, attributes from sqlalchemy.test.assertsql import AllOf, CompiledSQL from test.orm._fixtures import keywords, addresses, Base, Keyword, \ @@ -15,6 +16,23 @@ from test.orm._fixtures import keywords, addresses, Base, Keyword, \ class UOWTest(_fixtures.FixtureTest, testing.AssertsExecutionResults): run_inserts = None + def _assert_uow_size(self, + session, + expected + ): + uow = unitofwork.UOWTransaction(session) + deleted = set(session._deleted) + new = set(session._new) + dirty = set(session._dirty_states).difference(deleted) + for s in new.union(dirty): + uow.register_object(s) + for d in deleted: + uow.register_object(d, isdelete=True) + postsort_actions = uow._generate_actions() + print postsort_actions + eq_(len(postsort_actions), expected, postsort_actions) + + class RudimentaryFlushTest(UOWTest): def test_one_to_many_save(self): @@ -196,6 +214,43 @@ class RudimentaryFlushTest(UOWTest): ), ) + def test_o2m_flush_size(self): + mapper(User, users, properties={ + 'addresses':relationship(Address), + }) + mapper(Address, addresses) + + sess = create_session() + u1 = User(name='ed') + sess.add(u1) + self._assert_uow_size(sess, 2) + + sess.flush() + + u1.name='jack' + + self._assert_uow_size(sess, 2) + sess.flush() + + a1 = Address(email_address='foo') + sess.add(a1) + sess.flush() + + u1.addresses.append(a1) + + self._assert_uow_size(sess, 6) + + sess.flush() + + sess = create_session() + u1 = sess.query(User).first() + u1.name='ed' + self._assert_uow_size(sess, 2) + + u1.addresses + self._assert_uow_size(sess, 6) + + class SingleCycleTest(UOWTest): def test_one_to_many_save(self): mapper(Node, nodes, properties={ @@ -389,6 +444,40 @@ class SingleCycleTest(UOWTest): # sess.flush, # ) + def test_singlecycle_flush_size(self): + mapper(Node, nodes, properties={ + 'children':relationship(Node) + }) + sess = create_session() + n1 = Node(data='ed') + sess.add(n1) + self._assert_uow_size(sess, 2) + + sess.flush() + + n1.data='jack' + + self._assert_uow_size(sess, 2) + sess.flush() + + n2 = Node(data='foo') + sess.add(n2) + sess.flush() + + n1.children.append(n2) + + self._assert_uow_size(sess, 4) + + sess.flush() + + sess = create_session() + n1 = sess.query(Node).first() + n1.data='ed' + self._assert_uow_size(sess, 2) + + n1.children + self._assert_uow_size(sess, 3) + class SingleCycleM2MTest(_base.MappedTest, testing.AssertsExecutionResults): @classmethod @@ -538,6 +627,3 @@ class SingleCycleM2MTest(_base.MappedTest, testing.AssertsExecutionResults): ) - - - \ No newline at end of file