From: Mike Bayer Date: Sat, 10 May 2008 17:42:09 +0000 (+0000) Subject: merged r4720 from 04 branch for [ticket:1036] X-Git-Tag: rel_0_5beta1~112 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0ca037c65ed450b030c1306c04c18058ccf2c979;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git merged r4720 from 04 branch for [ticket:1036] --- diff --git a/CHANGES b/CHANGES index 894be6116d..534867319c 100644 --- a/CHANGES +++ b/CHANGES @@ -88,6 +88,10 @@ user_defined_state functions and other expressions. (partial progress towards [ticket:610]) + - repaired single table inheritance such that you + can single-table inherit from a joined-table inherting + mapper without issue [ticket:1036]. + - Fixed "concatenate tuple" bug which could occur with Query.order_by() if clause adaption had taken place. [ticket:1027] diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 234f339059..5569f9216b 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -434,8 +434,9 @@ class Mapper(object): # inherit_condition is optional. if self.local_table is None: self.local_table = self.inherits.local_table + self.mapped_table = self.inherits.mapped_table self.single = True - if not self.local_table is self.inherits.local_table: + elif not self.local_table is self.inherits.local_table: if self.concrete: self.mapped_table = self.local_table for mapper in self.iterate_to_root(): @@ -1556,7 +1557,7 @@ class Mapper(object): for mapper in util.reversed(list(self.iterate_to_root())): if mapper.local_table in tables: start = True - if start: + if start and not mapper.single: allconds.append(visitors.cloned_traverse(mapper.inherit_condition, {}, {'binary':visit_binary})) cond = sql.and_(*allconds) diff --git a/test/orm/inheritance/single.py b/test/orm/inheritance/single.py index dabb701cd9..58e7ad82a6 100644 --- a/test/orm/inheritance/single.py +++ b/test/orm/inheritance/single.py @@ -2,11 +2,10 @@ import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from testlib import * +from testlib.fixtures import Base - -class SingleInheritanceTest(TestBase, AssertsExecutionResults): - def setUpAll(self): - metadata = MetaData(testing.db) +class SingleInheritanceTest(ORMTest): + def define_tables(self, metadata): global employees_table employees_table = Table('employees', metadata, Column('employee_id', Integer, primary_key=True), @@ -15,43 +14,27 @@ class SingleInheritanceTest(TestBase, AssertsExecutionResults): Column('engineer_info', String(50)), Column('type', String(20)) ) - employees_table.create() - def tearDownAll(self): - employees_table.drop() - def testbasic(self): - class Employee(object): - def __init__(self, name): - self.name = name - def __repr__(self): - return self.__class__.__name__ + " " + self.name + def test_single_inheritance(self): + class Employee(Base): + pass class Manager(Employee): - def __init__(self, name, manager_data): - self.name = name - self.manager_data = manager_data - def __repr__(self): - return self.__class__.__name__ + " " + self.name + " " + self.manager_data - + pass class Engineer(Employee): - def __init__(self, name, engineer_info): - self.name = name - self.engineer_info = engineer_info - def __repr__(self): - return self.__class__.__name__ + " " + self.name + " " + self.engineer_info - + pass class JuniorEngineer(Engineer): pass - employee_mapper = mapper(Employee, employees_table, polymorphic_on=employees_table.c.type) - manager_mapper = mapper(Manager, inherits=employee_mapper, polymorphic_identity='manager') - engineer_mapper = mapper(Engineer, inherits=employee_mapper, polymorphic_identity='engineer') - junior_engineer = mapper(JuniorEngineer, inherits=engineer_mapper, polymorphic_identity='juniorengineer') + mapper(Employee, employees_table, polymorphic_on=employees_table.c.type) + mapper(Manager, inherits=Employee, polymorphic_identity='manager') + mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') + mapper(JuniorEngineer, inherits=Engineer, polymorphic_identity='juniorengineer') session = create_session() - m1 = Manager('Tom', 'knows how to manage things') - e1 = Engineer('Kurt', 'knows how to hack') - e2 = JuniorEngineer('Ed', 'oh that ed') + m1 = Manager(name='Tom', manager_data='knows how to manage things') + e1 = Engineer(name='Kurt', engineer_info='knows how to hack') + e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed') session.save(m1) session.save(e1) session.save(e2) @@ -66,5 +49,66 @@ class SingleInheritanceTest(TestBase, AssertsExecutionResults): session.expire(m1, ['manager_data']) self.assertEquals(m1.manager_data, "knows how to manage things") +class SingleOnJoinedTest(ORMTest): + def define_tables(self, metadata): + global persons_table, employees_table + + persons_table = Table('persons', metadata, + Column('person_id', Integer, primary_key=True), + Column('name', String(50)), + Column('type', String(20), nullable=False) + ) + + employees_table = Table('employees', metadata, + Column('person_id', Integer, ForeignKey('persons.person_id'),primary_key=True), + Column('employee_data', String(50)), + Column('manager_data', String(50)), + ) + + def test_single_on_joined(self): + class Person(Base): + pass + class Employee(Person): + pass + class Manager(Employee): + pass + + mapper(Person, persons_table, polymorphic_on=persons_table.c.type, polymorphic_identity='person') + mapper(Employee, employees_table, inherits=Person,polymorphic_identity='engineer') + mapper(Manager, inherits=Employee,polymorphic_identity='manager') + + sess = create_session() + sess.save(Person(name='p1')) + sess.save(Employee(name='e1', employee_data='ed1')) + sess.save(Manager(name='m1', employee_data='ed2', manager_data='md1')) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(Person).order_by(Person.person_id).all(), [ + Person(name='p1'), + Employee(name='e1', employee_data='ed1'), + Manager(name='m1', employee_data='ed2', manager_data='md1') + ]) + sess.clear() + + self.assertEquals(sess.query(Employee).order_by(Person.person_id).all(), [ + Employee(name='e1', employee_data='ed1'), + Manager(name='m1', employee_data='ed2', manager_data='md1') + ]) + sess.clear() + + self.assertEquals(sess.query(Manager).order_by(Person.person_id).all(), [ + Manager(name='m1', employee_data='ed2', manager_data='md1') + ]) + sess.clear() + + def go(): + self.assertEquals(sess.query(Person).with_polymorphic('*').order_by(Person.person_id).all(), [ + Person(name='p1'), + Employee(name='e1', employee_data='ed1'), + Manager(name='m1', employee_data='ed2', manager_data='md1') + ]) + self.assert_sql_count(testing.db, go, 1) + if __name__ == '__main__': testenv.main()