From: Mike Bayer Date: Fri, 4 Apr 2008 20:21:03 +0000 (+0000) Subject: refined "local_remote_pairs" a bit to account for the same columns repeated multiple... X-Git-Tag: rel_0_5beta1~198 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5b3cddc48e5b436a0c46f0df3b016a837d823c92;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git refined "local_remote_pairs" a bit to account for the same columns repeated multiple times --- diff --git a/CHANGES b/CHANGES index 750410db26..897b84188a 100644 --- a/CHANGES +++ b/CHANGES @@ -4,6 +4,12 @@ CHANGES 0.4.6 ===== +- orm + - Fix to the recent relation() refactoring which fixes + exotic relations which join between local and remote table + multiple times, with a common column shared between the + joins. + - sql - Fixed bug with union() when applied to non-Table connected select statements diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index a171923fe9..f18a6bddec 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -572,7 +572,7 @@ class PropertyLoader(StrategizedProperty): self.synchronize_pairs = eq_pairs if self.secondaryjoin: - sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=arg_foreign_keys) + sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=self.viewonly) sq_pairs = [(l, r) for l, r in sq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)] if not sq_pairs: @@ -590,24 +590,23 @@ class PropertyLoader(StrategizedProperty): else: self.secondary_synchronize_pairs = None - def local_remote_pairs(self): - return zip(self.local_side, self.remote_side) - local_remote_pairs = property(local_remote_pairs) - def __determine_remote_side(self): if self.remote_side: if self.direction is MANYTOONE: eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_referenced_keys=self.remote_side, any_operator=True) else: eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=self.remote_side, any_operator=True) - - if self.secondaryjoin: - sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=self.foreign_keys, any_operator=True) - sq_pairs = [(l, r) for l, r in sq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)] - eq_pairs += sq_pairs else: - eq_pairs = zip(self._opposite_side, self.foreign_keys) - + if self.viewonly: + eq_pairs = self.synchronize_pairs + else: + eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=self.foreign_keys, any_operator=True) + if self.secondaryjoin: + sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=self.foreign_keys, any_operator=True) + eq_pairs += sq_pairs + eq_pairs = [(l, r) for l, r in eq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)] + + self.local_remote_pairs = eq_pairs if self.direction is MANYTOONE: self.remote_side, self.local_side = [util.OrderedSet(s) for s in zip(*eq_pairs)] else: diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index e4c85b5d71..3bbb380d13 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -354,13 +354,13 @@ class LazyLoader(AbstractRelationLoader): equated_columns = {} secondaryjoin = prop.secondaryjoin - equated = dict(prop.local_remote_pairs) + local = prop.local_side def should_bind(targetcol, othercol): if reverse_direction and not secondaryjoin: - return othercol in equated + return othercol in local else: - return targetcol in equated + return targetcol in local def visit_binary(binary): leftcol = binary.left diff --git a/test/orm/relationships.py b/test/orm/relationships.py index 89fd86f275..31d8ab49d6 100644 --- a/test/orm/relationships.py +++ b/test/orm/relationships.py @@ -863,6 +863,7 @@ class ViewOnlyTest2(ORMTest): assert set([x.t2id for x in c1.t2_view]) == set([c2b.t2id]) class ViewOnlyTest3(ORMTest): + """test relating on a join that has no equated columns""" def define_tables(self, metadata): global foos, bars foos = Table('foos', metadata, Column('id', Integer, primary_key=True)) @@ -893,6 +894,88 @@ class ViewOnlyTest3(ORMTest): self.assertEquals(sess.query(Foo).filter_by(id=4).one(), Foo(id=4, bars=[Bar(fid=2), Bar(fid=3)])) self.assertEquals(sess.query(Foo).filter_by(id=9).one(), Foo(id=9, bars=[Bar(fid=2), Bar(fid=3), Bar(fid=6), Bar(fid=7)])) +class ViewOnlyTest4(ORMTest): + """test relating on a join that contains the same 'remote' column twice""" + def define_tables(self, metadata): + global foos, bars + foos = Table('foos', metadata, Column('id', Integer, primary_key=True), + Column('bid1', Integer,ForeignKey('bars.id')), + Column('bid2', Integer,ForeignKey('bars.id'))) + + bars = Table('bars', metadata, Column('id', Integer, primary_key=True), Column('data', String(50))) + + def test_relation_on_or(self): + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + pass + + mapper(Foo, foos, properties={ + 'bars':relation(Bar, primaryjoin=or_(bars.c.id==foos.c.bid1, bars.c.id==foos.c.bid2), uselist=True, viewonly=True) + }) + + mapper(Bar, bars) + sess = create_session() + b1 = Bar(id=1, data='b1') + b2 = Bar(id=2, data='b2') + b3 = Bar(id=3, data='b3') + f1 = Foo(bid1=1, bid2=2) + f2 = Foo(bid1=3, bid2=None) + sess.save(b1) + sess.save(b2) + sess.save(b3) + sess.save(f1) + sess.save(f2) + sess.flush() + + sess.clear() + self.assertEquals(sess.query(Foo).filter_by(id=f1.id).one(), Foo(bars=[Bar(data='b1'), Bar(data='b2')])) + self.assertEquals(sess.query(Foo).filter_by(id=f2.id).one(), Foo(bars=[Bar(data='b3')])) + +class ViewOnlyTest5(ORMTest): + """test relating on a join that contains the same 'local' column twice""" + def define_tables(self, metadata): + global foos, bars + foos = Table('foos', metadata, Column('id', Integer, primary_key=True), + Column('data', String(50)) + ) + + bars = Table('bars', metadata, Column('id', Integer, primary_key=True), + Column('fid1', Integer, ForeignKey('foos.id')), + Column('fid2', Integer, ForeignKey('foos.id')), + Column('data', String(50)) + ) + + def test_relation_on_or(self): + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + pass + + mapper(Foo, foos, properties={ + 'bars':relation(Bar, primaryjoin=or_(bars.c.fid1==foos.c.id, bars.c.fid2==foos.c.id), viewonly=True) + }) + + mapper(Bar, bars) + sess = create_session() + f1 = Foo(id=1, data='f1') + f2 = Foo(id=2, data='f2') + b1 = Bar(fid1=1, data='b1') + b2 = Bar(fid2=1, data='b2') + b3 = Bar(fid1=2, data='b3') + b4 = Bar(fid1=1, fid2=2, data='b4') + sess.save(b1) + sess.save(b2) + sess.save(b3) + sess.save(b4) + sess.save(f1) + sess.save(f2) + sess.flush() + + sess.clear() + self.assertEquals(sess.query(Foo).filter_by(id=f1.id).one(), Foo(bars=[Bar(data='b1'), Bar(data='b2'), Bar(data='b4')])) + self.assertEquals(sess.query(Foo).filter_by(id=f2.id).one(), Foo(bars=[Bar(data='b3'), Bar(data='b4')])) + class InvalidRelationEscalationTest(ORMTest): def define_tables(self, metadata): global foos, bars, Foo, Bar