From: Mike Bayer Date: Sat, 10 May 2008 17:31:07 +0000 (+0000) Subject: - repaired single table inheritance such that you X-Git-Tag: rel_0_4_6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6595a5af31ee473f4388db5917c59f0965a64279;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - repaired single table inheritance such that you can single-table inherit from a joined-table inherting mapper without issue [ticket:1036]. --- diff --git a/CHANGES b/CHANGES index 35d53ab618..df9120b64b 100644 --- a/CHANGES +++ b/CHANGES @@ -28,6 +28,10 @@ CHANGES functions and other expressions. (partial progress towards [ticket:610]) + - repaired single table inheritance such that you + can single-table inherit from a joined-table inherting + mapper without issue [ticket:1036]. + - Fixed "concatenate tuple" bug which could occur with Query.order_by() if clause adaption had taken place. [ticket:1027] diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index ba0644758f..448d797d06 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -420,8 +420,9 @@ class Mapper(object): # inherit_condition is optional. if self.local_table is None: self.local_table = self.inherits.local_table + self.mapped_table = self.inherits.mapped_table self.single = True - if not self.local_table is self.inherits.local_table: + elif not self.local_table is self.inherits.local_table: if self.concrete: self.mapped_table = self.local_table for mapper in self.iterate_to_root(): @@ -1592,7 +1593,8 @@ class Mapper(object): for mapper in self.iterate_to_root(): if mapper is base_mapper: break - allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary)) + if not mapper.single: + allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary)) return sql.and_(*allconds), param_names diff --git a/test/orm/inheritance/single.py b/test/orm/inheritance/single.py index 81223cc02e..2241afb0f8 100644 --- a/test/orm/inheritance/single.py +++ b/test/orm/inheritance/single.py @@ -2,11 +2,10 @@ import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from testlib import * +from testlib.fixtures import Base - -class SingleInheritanceTest(TestBase, AssertsExecutionResults): - def setUpAll(self): - metadata = MetaData(testing.db) +class SingleInheritanceTest(ORMTest): + def define_tables(self, metadata): global employees_table employees_table = Table('employees', metadata, Column('employee_id', Integer, primary_key=True), @@ -15,43 +14,27 @@ class SingleInheritanceTest(TestBase, AssertsExecutionResults): Column('engineer_info', String(50)), Column('type', String(20)) ) - employees_table.create() - def tearDownAll(self): - employees_table.drop() - def testbasic(self): - class Employee(object): - def __init__(self, name): - self.name = name - def __repr__(self): - return self.__class__.__name__ + " " + self.name + def test_single_inheritance(self): + class Employee(Base): + pass class Manager(Employee): - def __init__(self, name, manager_data): - self.name = name - self.manager_data = manager_data - def __repr__(self): - return self.__class__.__name__ + " " + self.name + " " + self.manager_data - + pass class Engineer(Employee): - def __init__(self, name, engineer_info): - self.name = name - self.engineer_info = engineer_info - def __repr__(self): - return self.__class__.__name__ + " " + self.name + " " + self.engineer_info - + pass class JuniorEngineer(Engineer): pass - employee_mapper = mapper(Employee, employees_table, polymorphic_on=employees_table.c.type) - manager_mapper = mapper(Manager, inherits=employee_mapper, polymorphic_identity='manager') - engineer_mapper = mapper(Engineer, inherits=employee_mapper, polymorphic_identity='engineer') - junior_engineer = mapper(JuniorEngineer, inherits=engineer_mapper, polymorphic_identity='juniorengineer') + mapper(Employee, employees_table, polymorphic_on=employees_table.c.type) + mapper(Manager, inherits=Employee, polymorphic_identity='manager') + mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') + mapper(JuniorEngineer, inherits=Engineer, polymorphic_identity='juniorengineer') session = create_session() - m1 = Manager('Tom', 'knows how to manage things') - e1 = Engineer('Kurt', 'knows how to hack') - e2 = JuniorEngineer('Ed', 'oh that ed') + m1 = Manager(name='Tom', manager_data='knows how to manage things') + e1 = Engineer(name='Kurt', engineer_info='knows how to hack') + e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed') session.save(m1) session.save(e1) session.save(e2) @@ -62,5 +45,66 @@ class SingleInheritanceTest(TestBase, AssertsExecutionResults): assert session.query(Manager).all() == [m1] assert session.query(JuniorEngineer).all() == [e2] +class SingleOnJoinedTest(ORMTest): + def define_tables(self, metadata): + global persons_table, employees_table + + persons_table = Table('persons', metadata, + Column('person_id', Integer, primary_key=True), + Column('name', String(50)), + Column('type', String(20), nullable=False) + ) + + employees_table = Table('employees', metadata, + Column('person_id', Integer, ForeignKey('persons.person_id'),primary_key=True), + Column('employee_data', String(50)), + Column('manager_data', String(50)), + ) + + def test_single_on_joined(self): + class Person(Base): + pass + class Employee(Person): + pass + class Manager(Employee): + pass + + mapper(Person, persons_table, polymorphic_on=persons_table.c.type, polymorphic_identity='person') + mapper(Employee, employees_table, inherits=Person,polymorphic_identity='engineer') + mapper(Manager, inherits=Employee,polymorphic_identity='manager') + + sess = create_session() + sess.save(Person(name='p1')) + sess.save(Employee(name='e1', employee_data='ed1')) + sess.save(Manager(name='m1', employee_data='ed2', manager_data='md1')) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(Person).order_by(Person.person_id).all(), [ + Person(name='p1'), + Employee(name='e1', employee_data='ed1'), + Manager(name='m1', employee_data='ed2', manager_data='md1') + ]) + sess.clear() + + self.assertEquals(sess.query(Employee).order_by(Person.person_id).all(), [ + Employee(name='e1', employee_data='ed1'), + Manager(name='m1', employee_data='ed2', manager_data='md1') + ]) + sess.clear() + + self.assertEquals(sess.query(Manager).order_by(Person.person_id).all(), [ + Manager(name='m1', employee_data='ed2', manager_data='md1') + ]) + sess.clear() + + def go(): + self.assertEquals(sess.query(Person).with_polymorphic('*').order_by(Person.person_id).all(), [ + Person(name='p1'), + Employee(name='e1', employee_data='ed1'), + Manager(name='m1', employee_data='ed2', manager_data='md1') + ]) + self.assert_sql_count(testing.db, go, 1) + if __name__ == '__main__': testenv.main()