if self.secondaryjoin is not None:
return sync.MANYTOMANY
elif self._is_self_referential():
- if list(self.foreignkey)[0].primary_key:
- return sync.MANYTOONE
+ # for a self referential mapper, if the "foreignkey" is a single or composite primary key,
+ # then we are "many to one", since the remote site of the relationship identifies a singular entity.
+ # otherwise we are "one to many".
+ for f in self.foreignkey:
+ if not f.primary_key:
+ return sync.ONETOMANY
else:
- return sync.ONETOMANY
+ return sync.MANYTOONE
elif len([c for c in self.foreignkey if self.mapper.unjoined_table.corresponding_column(c, False) is not None]):
return sync.ONETOMANY
elif len([c for c in self.foreignkey if self.parent.unjoined_table.corresponding_column(c, False) is not None]):
def tearDownAll(self):
metadata.drop_all()
- def testbasic(self):
+ def testexplicit(self):
+ """test with mappers that have fairly explicit join conditions"""
class Company(object):
pass
class Employee(object):
assert [x.name for x in test_e1.employees] == ['emp2', 'emp3']
assert sess.query(Employee).get([c1.company_id, 3]).reports_to.name == 'emp1'
assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5'
+
+ def testimplict(self):
+ """test with mappers that have the most minimal arguments"""
+ class Company(object):
+ pass
+ class Employee(object):
+ def __init__(self, name, company, emp_id, reports_to=None):
+ self.name = name
+ self.company = company
+ self.emp_id = emp_id
+ self.reports_to = reports_to
+
+ mapper(Company, company_tbl)
+ mapper(Employee, employee_tbl, properties= {
+ 'company':relation(Company, backref='employees'),
+ 'reports_to':relation(Employee,
+ foreignkey=[employee_tbl.c.company_id, employee_tbl.c.emp_id],
+ backref='employees')
+ })
+
+ sess = create_session()
+ c1 = Company()
+ c2 = Company()
+
+ e1 = Employee('emp1', c1, 1)
+ e2 = Employee('emp2', c1, 2, e1)
+ e3 = Employee('emp3', c1, 3, e1)
+ e4 = Employee('emp4', c1, 4, e3)
+ e5 = Employee('emp5', c2, 1)
+ e6 = Employee('emp6', c2, 2, e5)
+ e7 = Employee('emp7', c2, 3, e5)
+
+ [sess.save(x) for x in [c1,c2]]
+ sess.flush()
+ sess.clear()
+
+ test_c1 = sess.query(Company).get(c1.company_id)
+ test_e1 = sess.query(Employee).get([c1.company_id, e1.emp_id])
+ assert test_e1.name == 'emp1'
+ test_e5 = sess.query(Employee).get([c2.company_id, e5.emp_id])
+ assert test_e5.name == 'emp5'
+ assert [x.name for x in test_e1.employees] == ['emp2', 'emp3']
+ assert sess.query(Employee).get([c1.company_id, 3]).reports_to.name == 'emp1'
+ assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5'