From 037ee9fbf80279b4fc8e8d990860082c2f4dfea5 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 17 Mar 2007 20:46:52 +0000 Subject: [PATCH] - flush fixes on self-referential relationships that contain references to other instances outside of the cyclical chain, when the initial self-referential objects are not actually part of the flush --- CHANGES | 4 ++ lib/sqlalchemy/orm/dependency.py | 1 + lib/sqlalchemy/orm/unitofwork.py | 38 ++++++++--- test/orm/cycles.py | 109 +++++++++++++++++++++++++++++-- 4 files changed, 138 insertions(+), 14 deletions(-) diff --git a/CHANGES b/CHANGES index 9df40b9b3b..49e445aada 100644 --- a/CHANGES +++ b/CHANGES @@ -126,6 +126,10 @@ targeting of columns that belong to the polymorphic union vs. those that dont. + - flush fixes on self-referential relationships that contain references + to other instances outside of the cyclical chain, when the initial + self-referential objects are not actually part of the flush + - put an aggressive check for "flushing object A with a collection of B's, but you put a C in the collection" error condition - **even if C is a subclass of B**, unless B's mapper loads polymorphically. diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index a528c13274..c2d56fde82 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -410,6 +410,7 @@ class MapperStub(object): def __init__(self, parent, mapper, key): self.mapper = mapper + self.class_ = mapper.class_ self._inheriting_mappers = [] def register_dependencies(self, uowcommit): diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 587c405e65..70f0e1d0b3 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -406,7 +406,9 @@ class UOWTransaction(object): mappers = self._get_noninheriting_mappers() head = DependencySorter(self.dependencies, list(mappers)).sort(allow_all_cycles=True) - self.logger.debug("Dependency sort:\n"+ str(head)) + if logging.is_debug_enabled(self.logger): + self.logger.debug("Dependent tuples:\n" + "\n".join(["(%s->%s)" % (d[0].class_.__name__, d[1].class_.__name__) for d in self.dependencies])) + self.logger.debug("Dependency sort:\n"+ str(head)) task = sort_hier(head) return task @@ -726,17 +728,16 @@ class UOWTask(object): get_dependency_task(obj, dep).append(obj, isdelete=isdelete) #print "TUPLES", tuples + #print "ALLOBJECTS", allobjects head = DependencySorter(tuples, allobjects).sort() - if head is None: - return None - - #print str(head) # create a tree of UOWTasks corresponding to the tree of object instances # created by the DependencySorter + used_tasks = util.Set() def make_task_tree(node, parenttask, nexttasks): originating_task = object_to_original_task[node.item] + used_tasks.add(originating_task) t = nexttasks.get(originating_task, None) if t is None: t = UOWTask(self.uowtransaction, originating_task.mapper, circular_parent=self) @@ -755,12 +756,29 @@ class UOWTask(object): # this is the new "circular" UOWTask which will execute in place of "self" t = UOWTask(self.uowtransaction, self.mapper, circular_parent=self) - # stick the non-circular dependencies and child tasks onto the new - # circular UOWTask - [t.dependencies.add(d) for d in extradeplist] + # stick the non-circular dependencies onto the new UOWTask + for d in extradeplist: + t.dependencies.add(d) + + # share the "childtasks" list with the new UOWTask. more elements + # may be appended to this "childtasks" list in the enclosing + # _sort_dependencies() operation that is calling us. t.childtasks = self.childtasks - make_task_tree(head, t, {}) - #print t.dump() + + # if we have a head from the dependency sort, assemble child nodes + # onto the tree. note this only occurs if there were actual objects + # to be saved/deleted. + if head is not None: + make_task_tree(head, t, {}) + + for t2 in cycles: + # tasks that were in the cycle but did not get assembled + # into the tree, add them as child tasks. these tasks + # will have no "save" or "delete" members, but may have dependency + # processors that operate upon other tasks outside of the cycle. + if t2 not in used_tasks and t2 is not self: + t.childtasks.insert(0, t2) + return t def dump(self): diff --git a/test/orm/cycles.py b/test/orm/cycles.py index 2c6a3541cf..69ecbc0426 100644 --- a/test/orm/cycles.py +++ b/test/orm/cycles.py @@ -1,4 +1,4 @@ -from testbase import PersistTest, AssertMixin +from testbase import PersistTest, AssertMixin, ORMTest import unittest, sys, os from sqlalchemy import * import StringIO @@ -8,7 +8,12 @@ from tables import * import tables """test cyclical mapper relationships. Many of the assertions are provided -via running with postgres, which is strict about foreign keys.""" +via running with postgres, which is strict about foreign keys. + +we might want to try an automated generate of much of this, all combos of T1<->T2, with +o2m or m2o between them, and a third T3 with o2m/m2o to one/both T1/T2. +""" + class Tester(object): def __init__(self, data=None): @@ -120,6 +125,7 @@ class SelfReferentialTest(AssertMixin): assert True class SelfReferentialNoPKTest(AssertMixin): + """test self-referential relationship that joins on a column other than the primary key column""" def setUpAll(self): global table, meta meta = BoundMetaData(testbase.db) @@ -224,7 +230,103 @@ class InheritTestOne(AssertMixin): # the flush will fail if the UOW does not set up a many-to-one DP # attached to a task corresponding to c1, since "child1_id" is not nullable session.flush() + + +class BiDirectionalManyToOneTest(ORMTest): + def define_tables(self, metadata): + global t1, t2, t3, t4 + t1 = Table('t1', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)), + Column('t2id', Integer, ForeignKey('t2.id')) + ) + t2 = Table('t2', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)), + Column('t1id', Integer, ForeignKey('t1.id', use_alter=True, name="foo_fk")) + ) + t3 = Table('t3', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)), + Column('t1id', Integer, ForeignKey('t1.id'), nullable=False), + Column('t2id', Integer, ForeignKey('t2.id'), nullable=False), + ) + + def test_reflush(self): + class T1(object):pass + class T2(object):pass + class T3(object):pass + + mapper(T1, t1, properties={ + 't2':relation(T2, primaryjoin=t1.c.t2id==t2.c.id) + }) + mapper(T2, t2, properties={ + 't1':relation(T1, primaryjoin=t2.c.t1id==t1.c.id) + }) + mapper(T3, t3, properties={ + 't1':relation(T1), + 't2':relation(T2) + }) + + o1 = T1() + o1.t2 = T2() + sess = create_session() + sess.save(o1) + sess.flush() + # the bug here is that the dependency sort comes up with T1/T2 in a cycle, but there + # are no T1/T2 objects to be saved. therefore no "cyclical subtree" gets generated, + # and one or the other of T1/T2 gets lost, and processors on T3 dont fire off. + # the test will then fail because the FK's on T3 are not nullable. + o3 = T3() + o3.t1 = o1 + o3.t2 = o1.t2 + sess.save(o3) + sess.flush() + + + def test_reflush_2(self): + """a variant on test_reflush()""" + class T1(object):pass + class T2(object):pass + class T3(object):pass + + mapper(T1, t1, properties={ + 't2':relation(T2, primaryjoin=t1.c.t2id==t2.c.id) + }) + mapper(T2, t2, properties={ + 't1':relation(T1, primaryjoin=t2.c.t1id==t1.c.id) + }) + mapper(T3, t3, properties={ + 't1':relation(T1), + 't2':relation(T2) + }) + + o1 = T1() + o1.t2 = T2() + sess = create_session() + sess.save(o1) + sess.flush() + + # in this case, T1, T2, and T3 tasks will all be in the cyclical + # tree normally. the dependency processors for T3 are part of the + # 'extradeps' collection so they all get assembled into the tree + # as well. + o1a = T1() + o2a = T2() + sess.save(o1a) + sess.save(o2a) + o3b = T3() + o3b.t1 = o1a + o3b.t2 = o2a + sess.save(o3b) + + o3 = T3() + o3.t1 = o1 + o3.t2 = o1.t2 + sess.save(o3) + sess.flush() + class BiDirectionalOneToManyTest(AssertMixin): """tests two mappers with a one-to-many relation to each other.""" def setUpAll(self): @@ -370,8 +472,6 @@ class OneToManyManyToOneTest(AssertMixin): ) ) - print str(Person.mapper.props['balls'].primaryjoin) - b = Ball() p = Person() p.balls.append(b) @@ -623,6 +723,7 @@ class OneToManyManyToOneTest(AssertMixin): ]) class SelfReferentialPostUpdateTest(AssertMixin): + """test using post_update on a single self-referential mapper""" def setUpAll(self): global metadata, node_table metadata = BoundMetaData(testbase.db) -- 2.47.2