]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- improvements to the "determine direction" logic of
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 29 Jan 2009 06:40:29 +0000 (06:40 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 29 Jan 2009 06:40:29 +0000 (06:40 +0000)
relation() such that the direction of tricky situations
like mapper(A.join(B)) -> relation-> mapper(B) can be
determined.

CHANGES
lib/sqlalchemy/orm/properties.py
test/orm/relationships.py

diff --git a/CHANGES b/CHANGES
index b463d2814f4e589089eeb7fad62ed2884d31a7ef..e5cc729b9ea463b80e12be94f15a4932a8df84f1 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -12,6 +12,11 @@ CHANGES
       union(query1, query2), select([foo]).select_from(query), 
       etc.
 
+    - improvements to the "determine direction" logic of 
+      relation() such that the direction of tricky situations 
+      like mapper(A.join(B)) -> relation-> mapper(B) can be
+      determined.
+      
 - sql
     - the __selectable__() interface has been replaced entirely
       by __clause_element__().
index f05613f5c0bad899eb75cbb294030757e37ebdd8..73e7943706ba7a3936d16b2b6e314f5e441fe8d0 100644 (file)
@@ -889,29 +889,45 @@ class RelationProperty(StrategizedProperty):
                 self.direction = MANYTOONE
 
         else:
-            for mappedtable, parenttable in [(self.mapper.mapped_table, self.parent.mapped_table), (self.mapper.local_table, self.parent.local_table)]:
-                onetomany = [c for c in self._foreign_keys if mappedtable.c.contains_column(c)]
-                manytoone = [c for c in self._foreign_keys if parenttable.c.contains_column(c)]
-
-                if not onetomany and not manytoone:
-                    raise sa_exc.ArgumentError(
-                        "Can't determine relation direction for relationship '%s' "
-                        "- foreign key columns are present in neither the "
-                        "parent nor the child's mapped tables" %(str(self)))
-                elif onetomany and manytoone:
-                    continue
-                elif onetomany:
+            foreign_keys = [f for c, f in self.synchronize_pairs]
+
+            parentcols = util.column_set(self.parent.mapped_table.c)
+            targetcols = util.column_set(self.mapper.mapped_table.c)
+
+            # fk collection which suggests ONETOMANY.
+            onetomany_fk = targetcols.intersection(foreign_keys)
+
+            # fk collection which suggests MANYTOONE.
+            manytoone_fk = parentcols.intersection(foreign_keys)
+            
+            if not onetomany_fk and not manytoone_fk:
+                raise sa_exc.ArgumentError(
+                    "Can't determine relation direction for relationship '%s' "
+                    "- foreign key columns are present in neither the "
+                    "parent nor the child's mapped tables" % self )
+
+            elif onetomany_fk and manytoone_fk: 
+                # fks on both sides.  do the same
+                # test only based on the local side.
+                referents = [c for c, f in self.synchronize_pairs]
+                onetomany_local = parentcols.intersection(referents)
+                manytoone_local = targetcols.intersection(referents)
+
+                if onetomany_local and not manytoone_local:
                     self.direction = ONETOMANY
-                    break
-                elif manytoone:
+                elif manytoone_local and not onetomany_local:
                     self.direction = MANYTOONE
-                    break
-            else:
+            elif onetomany_fk:
+                self.direction = ONETOMANY
+            elif manytoone_fk:
+                self.direction = MANYTOONE
+                
+            if not self.direction:
                 raise sa_exc.ArgumentError(
                     "Can't determine relation direction for relationship '%s' "
                     "- foreign key columns are present in both the parent and "
                     "the child's mapped tables.  Specify 'foreign_keys' "
-                    "argument." % (str(self)))
+                    "argument." % self)
         
         if self.cascade.delete_orphan and not self.single_parent and \
             (self.direction is MANYTOMANY or self.direction is MANYTOONE):
@@ -1001,7 +1017,7 @@ class RelationProperty(StrategizedProperty):
         
 
     def _refers_to_parent_table(self):
-        return self.parent.mapped_table is self.target or self.parent.mapped_table is self.target
+        return self.parent.mapped_table is self.target
 
     def _is_self_referential(self):
         return self.mapper.common_parent(self.parent)
index 532203ce2040d8defdbe221d5cdbba30c2c0c563..8dfc8f95a69a062bd34404354101a253b04181f4 100644 (file)
@@ -1,7 +1,7 @@
 import testenv; testenv.configure_for_tests()
 import datetime
 from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData
+from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData, and_
 from testlib.sa.orm import mapper, relation, backref, create_session, compile_mappers, clear_mappers
 from testlib.testing import eq_, startswith_
 from orm import _base, _fixtures
@@ -650,6 +650,78 @@ class RelationTest6(_base.MappedTest):
             [TagInstance(data='iplc_case'), TagInstance(data='not_iplc_case')]
         )
 
+class AmbiguousJoinInterpretedAsSelfRef(_base.MappedTest):
+    """test ambiguous joins due to FKs on both sides treated as self-referential.
+    
+    this mapping is very similar to that of test/orm/inheritance/query.py
+    SelfReferentialTestJoinedToBase , except that inheritance is not used
+    here.
+    
+    """
+    
+    def define_tables(self, metadata):
+        subscriber_table = Table('subscriber', metadata,
+           Column('id', Integer, primary_key=True),
+          )
+
+        address_table = Table('address',
+                 metadata,
+                 Column('subscriber_id', Integer, ForeignKey('subscriber.id'), primary_key=True),
+                 Column('type', String(1), primary_key=True),
+                 )
+
+    @testing.resolve_artifact_names
+    def setup_mappers(self):
+        subscriber_and_address = subscriber.join(address, 
+               and_(address.c.subscriber_id==subscriber.c.id, address.c.type.in_(['A', 'B', 'C'])))
+
+        class Address(_base.ComparableEntity):
+            pass
+
+        class Subscriber(_base.ComparableEntity):
+            pass
+
+        mapper(Address, address)
+
+        mapper(Subscriber, subscriber_and_address, properties={
+           'id':[subscriber.c.id, address.c.subscriber_id],
+           'addresses' : relation(Address, 
+                backref=backref("customer"))
+           })
+        
+    @testing.resolve_artifact_names
+    def test_mapping(self):
+        from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE
+        sess = create_session()
+        assert Subscriber.addresses.property.direction is ONETOMANY
+        assert Address.customer.property.direction is MANYTOONE
+        
+        s1 = Subscriber(type='A',
+                addresses = [
+                    Address(type='D'),
+                    Address(type='E'),
+                ]
+        )
+        a1 = Address(type='B', customer=Subscriber(type='C'))
+        
+        assert s1.addresses[0].customer is s1
+        assert a1.customer.addresses[0] is a1
+        
+        sess.add_all([s1, a1])
+        
+        sess.flush()
+        sess.expunge_all()
+        
+        eq_(
+            sess.query(Subscriber).order_by(Subscriber.type).all(),
+            [
+                Subscriber(id=1, type=u'A'), 
+                Subscriber(id=2, type=u'B'), 
+                Subscriber(id=2, type=u'C')
+            ]
+        )
+
+
 class ManualBackrefTest(_fixtures.FixtureTest):
     """Test explicit relations that are backrefs to each other."""