From: Mike Bayer Date: Thu, 29 Jan 2009 06:40:29 +0000 (+0000) Subject: - improvements to the "determine direction" logic of X-Git-Tag: rel_0_5_3~48 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=966119f4d31ec511930024ad4b12d5e53cc8a6ec;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - improvements to the "determine direction" logic of relation() such that the direction of tricky situations like mapper(A.join(B)) -> relation-> mapper(B) can be determined. --- diff --git a/CHANGES b/CHANGES index b463d2814f..e5cc729b9e 100644 --- a/CHANGES +++ b/CHANGES @@ -12,6 +12,11 @@ CHANGES union(query1, query2), select([foo]).select_from(query), etc. + - improvements to the "determine direction" logic of + relation() such that the direction of tricky situations + like mapper(A.join(B)) -> relation-> mapper(B) can be + determined. + - sql - the __selectable__() interface has been replaced entirely by __clause_element__(). diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index f05613f5c0..73e7943706 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -889,29 +889,45 @@ class RelationProperty(StrategizedProperty): self.direction = MANYTOONE else: - for mappedtable, parenttable in [(self.mapper.mapped_table, self.parent.mapped_table), (self.mapper.local_table, self.parent.local_table)]: - onetomany = [c for c in self._foreign_keys if mappedtable.c.contains_column(c)] - manytoone = [c for c in self._foreign_keys if parenttable.c.contains_column(c)] - - if not onetomany and not manytoone: - raise sa_exc.ArgumentError( - "Can't determine relation direction for relationship '%s' " - "- foreign key columns are present in neither the " - "parent nor the child's mapped tables" %(str(self))) - elif onetomany and manytoone: - continue - elif onetomany: + foreign_keys = [f for c, f in self.synchronize_pairs] + + parentcols = util.column_set(self.parent.mapped_table.c) + targetcols = util.column_set(self.mapper.mapped_table.c) + + # fk collection which suggests ONETOMANY. + onetomany_fk = targetcols.intersection(foreign_keys) + + # fk collection which suggests MANYTOONE. + manytoone_fk = parentcols.intersection(foreign_keys) + + if not onetomany_fk and not manytoone_fk: + raise sa_exc.ArgumentError( + "Can't determine relation direction for relationship '%s' " + "- foreign key columns are present in neither the " + "parent nor the child's mapped tables" % self ) + + elif onetomany_fk and manytoone_fk: + # fks on both sides. do the same + # test only based on the local side. + referents = [c for c, f in self.synchronize_pairs] + onetomany_local = parentcols.intersection(referents) + manytoone_local = targetcols.intersection(referents) + + if onetomany_local and not manytoone_local: self.direction = ONETOMANY - break - elif manytoone: + elif manytoone_local and not onetomany_local: self.direction = MANYTOONE - break - else: + elif onetomany_fk: + self.direction = ONETOMANY + elif manytoone_fk: + self.direction = MANYTOONE + + if not self.direction: raise sa_exc.ArgumentError( "Can't determine relation direction for relationship '%s' " "- foreign key columns are present in both the parent and " "the child's mapped tables. Specify 'foreign_keys' " - "argument." % (str(self))) + "argument." % self) if self.cascade.delete_orphan and not self.single_parent and \ (self.direction is MANYTOMANY or self.direction is MANYTOONE): @@ -1001,7 +1017,7 @@ class RelationProperty(StrategizedProperty): def _refers_to_parent_table(self): - return self.parent.mapped_table is self.target or self.parent.mapped_table is self.target + return self.parent.mapped_table is self.target def _is_self_referential(self): return self.mapper.common_parent(self.parent) diff --git a/test/orm/relationships.py b/test/orm/relationships.py index 532203ce20..8dfc8f95a6 100644 --- a/test/orm/relationships.py +++ b/test/orm/relationships.py @@ -1,7 +1,7 @@ import testenv; testenv.configure_for_tests() import datetime from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData +from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData, and_ from testlib.sa.orm import mapper, relation, backref, create_session, compile_mappers, clear_mappers from testlib.testing import eq_, startswith_ from orm import _base, _fixtures @@ -650,6 +650,78 @@ class RelationTest6(_base.MappedTest): [TagInstance(data='iplc_case'), TagInstance(data='not_iplc_case')] ) +class AmbiguousJoinInterpretedAsSelfRef(_base.MappedTest): + """test ambiguous joins due to FKs on both sides treated as self-referential. + + this mapping is very similar to that of test/orm/inheritance/query.py + SelfReferentialTestJoinedToBase , except that inheritance is not used + here. + + """ + + def define_tables(self, metadata): + subscriber_table = Table('subscriber', metadata, + Column('id', Integer, primary_key=True), + ) + + address_table = Table('address', + metadata, + Column('subscriber_id', Integer, ForeignKey('subscriber.id'), primary_key=True), + Column('type', String(1), primary_key=True), + ) + + @testing.resolve_artifact_names + def setup_mappers(self): + subscriber_and_address = subscriber.join(address, + and_(address.c.subscriber_id==subscriber.c.id, address.c.type.in_(['A', 'B', 'C']))) + + class Address(_base.ComparableEntity): + pass + + class Subscriber(_base.ComparableEntity): + pass + + mapper(Address, address) + + mapper(Subscriber, subscriber_and_address, properties={ + 'id':[subscriber.c.id, address.c.subscriber_id], + 'addresses' : relation(Address, + backref=backref("customer")) + }) + + @testing.resolve_artifact_names + def test_mapping(self): + from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE + sess = create_session() + assert Subscriber.addresses.property.direction is ONETOMANY + assert Address.customer.property.direction is MANYTOONE + + s1 = Subscriber(type='A', + addresses = [ + Address(type='D'), + Address(type='E'), + ] + ) + a1 = Address(type='B', customer=Subscriber(type='C')) + + assert s1.addresses[0].customer is s1 + assert a1.customer.addresses[0] is a1 + + sess.add_all([s1, a1]) + + sess.flush() + sess.expunge_all() + + eq_( + sess.query(Subscriber).order_by(Subscriber.type).all(), + [ + Subscriber(id=1, type=u'A'), + Subscriber(id=2, type=u'B'), + Subscriber(id=2, type=u'C') + ] + ) + + class ManualBackrefTest(_fixtures.FixtureTest): """Test explicit relations that are backrefs to each other."""