]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
more fixup to self referential composite primary key mappings
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Oct 2006 03:16:38 +0000 (03:16 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Oct 2006 03:16:38 +0000 (03:16 +0000)
lib/sqlalchemy/orm/properties.py
test/orm/relationships.py

index 2ad2c2b8c9d165d1b2843000c4167781b943652b..11c61246747d5d60672329a543e22b9480c60581 100644 (file)
@@ -207,10 +207,14 @@ class PropertyLoader(StrategizedProperty):
         if self.secondaryjoin is not None:
             return sync.MANYTOMANY
         elif self._is_self_referential():
-            if list(self.foreignkey)[0].primary_key:
-                return sync.MANYTOONE
+            # for a self referential mapper, if the "foreignkey" is a single or composite primary key,
+            # then we are "many to one", since the remote site of the relationship identifies a singular entity.
+            # otherwise we are "one to many".
+            for f in self.foreignkey:
+                if not f.primary_key:
+                    return sync.ONETOMANY
             else:
-                return sync.ONETOMANY
+                return sync.MANYTOONE
         elif len([c for c in self.foreignkey if self.mapper.unjoined_table.corresponding_column(c, False) is not None]):
             return sync.ONETOMANY
         elif len([c for c in self.foreignkey if self.parent.unjoined_table.corresponding_column(c, False) is not None]):
index dcf87c4f6866c42df4970ad5ed9e66c348826ad9..ae0101108746474984993431f69bb91135b46d1b 100644 (file)
@@ -125,7 +125,8 @@ class RelationTest2(testbase.PersistTest):
     def tearDownAll(self):
         metadata.drop_all()    
 
-    def testbasic(self):
+    def testexplicit(self):
+        """test with mappers that have fairly explicit join conditions"""
         class Company(object):
             pass
         class Employee(object):
@@ -171,6 +172,50 @@ class RelationTest2(testbase.PersistTest):
         assert [x.name for x in test_e1.employees] == ['emp2', 'emp3']
         assert sess.query(Employee).get([c1.company_id, 3]).reports_to.name == 'emp1'
         assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5'
+
+    def testimplict(self):
+        """test with mappers that have the most minimal arguments"""
+        class Company(object):
+            pass
+        class Employee(object):
+            def __init__(self, name, company, emp_id, reports_to=None):
+                self.name = name
+                self.company = company
+                self.emp_id = emp_id
+                self.reports_to = reports_to
+
+        mapper(Company, company_tbl)
+        mapper(Employee, employee_tbl, properties= {
+            'company':relation(Company, backref='employees'),
+            'reports_to':relation(Employee, 
+                foreignkey=[employee_tbl.c.company_id, employee_tbl.c.emp_id],
+                backref='employees')
+        })
+
+        sess = create_session()
+        c1 = Company()
+        c2 = Company()
+
+        e1 = Employee('emp1', c1, 1)
+        e2 = Employee('emp2', c1, 2, e1)
+        e3 = Employee('emp3', c1, 3, e1)
+        e4 = Employee('emp4', c1, 4, e3)
+        e5 = Employee('emp5', c2, 1)
+        e6 = Employee('emp6', c2, 2, e5)
+        e7 = Employee('emp7', c2, 3, e5)
+
+        [sess.save(x) for x in [c1,c2]]
+        sess.flush()
+        sess.clear()
+
+        test_c1 = sess.query(Company).get(c1.company_id)
+        test_e1 = sess.query(Employee).get([c1.company_id, e1.emp_id])
+        assert test_e1.name == 'emp1'
+        test_e5 = sess.query(Employee).get([c2.company_id, e5.emp_id])
+        assert test_e5.name == 'emp5'
+        assert [x.name for x in test_e1.employees] == ['emp2', 'emp3']
+        assert sess.query(Employee).get([c1.company_id, 3]).reports_to.name == 'emp1'
+        assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5'