From: Mike Bayer Date: Sun, 8 Oct 2006 03:16:38 +0000 (+0000) Subject: more fixup to self referential composite primary key mappings X-Git-Tag: rel_0_3_0~74 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b0ffcbc264f6a92ba5092e5d785a2dbfe418c307;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git more fixup to self referential composite primary key mappings --- diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 2ad2c2b8c9..11c6124674 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -207,10 +207,14 @@ class PropertyLoader(StrategizedProperty): if self.secondaryjoin is not None: return sync.MANYTOMANY elif self._is_self_referential(): - if list(self.foreignkey)[0].primary_key: - return sync.MANYTOONE + # for a self referential mapper, if the "foreignkey" is a single or composite primary key, + # then we are "many to one", since the remote site of the relationship identifies a singular entity. + # otherwise we are "one to many". + for f in self.foreignkey: + if not f.primary_key: + return sync.ONETOMANY else: - return sync.ONETOMANY + return sync.MANYTOONE elif len([c for c in self.foreignkey if self.mapper.unjoined_table.corresponding_column(c, False) is not None]): return sync.ONETOMANY elif len([c for c in self.foreignkey if self.parent.unjoined_table.corresponding_column(c, False) is not None]): diff --git a/test/orm/relationships.py b/test/orm/relationships.py index dcf87c4f68..ae01011087 100644 --- a/test/orm/relationships.py +++ b/test/orm/relationships.py @@ -125,7 +125,8 @@ class RelationTest2(testbase.PersistTest): def tearDownAll(self): metadata.drop_all() - def testbasic(self): + def testexplicit(self): + """test with mappers that have fairly explicit join conditions""" class Company(object): pass class Employee(object): @@ -171,6 +172,50 @@ class RelationTest2(testbase.PersistTest): assert [x.name for x in test_e1.employees] == ['emp2', 'emp3'] assert sess.query(Employee).get([c1.company_id, 3]).reports_to.name == 'emp1' assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5' + + def testimplict(self): + """test with mappers that have the most minimal arguments""" + class Company(object): + pass + class Employee(object): + def __init__(self, name, company, emp_id, reports_to=None): + self.name = name + self.company = company + self.emp_id = emp_id + self.reports_to = reports_to + + mapper(Company, company_tbl) + mapper(Employee, employee_tbl, properties= { + 'company':relation(Company, backref='employees'), + 'reports_to':relation(Employee, + foreignkey=[employee_tbl.c.company_id, employee_tbl.c.emp_id], + backref='employees') + }) + + sess = create_session() + c1 = Company() + c2 = Company() + + e1 = Employee('emp1', c1, 1) + e2 = Employee('emp2', c1, 2, e1) + e3 = Employee('emp3', c1, 3, e1) + e4 = Employee('emp4', c1, 4, e3) + e5 = Employee('emp5', c2, 1) + e6 = Employee('emp6', c2, 2, e5) + e7 = Employee('emp7', c2, 3, e5) + + [sess.save(x) for x in [c1,c2]] + sess.flush() + sess.clear() + + test_c1 = sess.query(Company).get(c1.company_id) + test_e1 = sess.query(Employee).get([c1.company_id, e1.emp_id]) + assert test_e1.name == 'emp1' + test_e5 = sess.query(Employee).get([c2.company_id, e5.emp_id]) + assert test_e5.name == 'emp5' + assert [x.name for x in test_e1.employees] == ['emp2', 'emp3'] + assert sess.query(Employee).get([c1.company_id, 3]).reports_to.name == 'emp1' + assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5'