From: Mike Bayer Date: Fri, 3 Nov 2006 19:57:39 +0000 (+0000) Subject: - improvement to single table inheritance to load full hierarchies beneath X-Git-Tag: rel_0_3_1~15 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=03068225263255a5f74ff97ad3d5d287d1b233da;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - improvement to single table inheritance to load full hierarchies beneath the target class --- diff --git a/CHANGES b/CHANGES index f8ba46f0f8..058804bcfe 100644 --- a/CHANGES +++ b/CHANGES @@ -24,6 +24,8 @@ that the class of object attached to a parent object is appropriate (i.e. if A.items stores B objects, raise an error if a C is appended to A.items) - new extension sqlalchemy.ext.associationproxy, provides transparent "association object" mappings. new example examples/association/proxied_association.py illustrates. +- improvement to single table inheritance to load full hierarchies beneath +the target class 0.3.0 - General: diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 84a2540b76..3327f13c34 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -595,13 +595,17 @@ class Mapper(object): m = m.inherits def polymorphic_iterator(self): - m = self.base_mapper() + """iterates through the collection including this mapper and all descendant mappers. + + this includes not just the immediately inheriting mappers but all their inheriting mappers as well. + + To iterate through an entire hierarchy, use mapper.base_mapper().polymorphic_iterator().""" def iterate(m): yield m for mapper in m._inheriting_mappers: for x in iterate(mapper): yield x - return iterate(m) + return iterate(self) def add_properties(self, dict_of_properties): """adds the given dictionary of properties to this mapper, using add_property.""" @@ -831,7 +835,7 @@ class Mapper(object): updated_objects = util.Set() table_to_mapper = {} - for mapper in self.polymorphic_iterator(): + for mapper in self.base_mapper().polymorphic_iterator(): for t in mapper.tables: table_to_mapper.setdefault(t, mapper) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index e257a1cbe1..208c7372b8 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -414,7 +414,7 @@ class Query(object): raise exceptions.ArgumentError("Unknown lockmode '%s'" % lockmode) if self.mapper.single and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None: - whereclause = sql.and_(whereclause, self.mapper.polymorphic_on==self.mapper.polymorphic_identity) + whereclause = sql.and_(whereclause, self.mapper.polymorphic_on.in_(*[m.polymorphic_identity for m in self.mapper.polymorphic_iterator()])) alltables = [] for l in [sql_util.TableFinder(x) for x in from_obj]: diff --git a/test/orm/alltests.py b/test/orm/alltests.py index daa7e38a8f..70d0bb6f4c 100644 --- a/test/orm/alltests.py +++ b/test/orm/alltests.py @@ -27,6 +27,7 @@ def suite(): 'orm.inheritance', 'orm.inheritance2', 'orm.inheritance3', + 'orm.single', 'orm.polymorph' ) alltests = unittest.TestSuite() diff --git a/test/orm/single.py b/test/orm/single.py new file mode 100644 index 0000000000..d78084fe11 --- /dev/null +++ b/test/orm/single.py @@ -0,0 +1,63 @@ +from sqlalchemy import * +import testbase + +class SingleInheritanceTest(testbase.AssertMixin): + def setUpAll(self): + metadata = BoundMetaData(testbase.db) + global employees_table + employees_table = Table('employees', metadata, + Column('employee_id', Integer, primary_key=True), + Column('name', String(50)), + Column('manager_data', String(50)), + Column('engineer_info', String(50)), + Column('type', String(20)) + ) + employees_table.create() + def tearDownAll(self): + employees_table.drop() + def testbasic(self): + class Employee(object): + def __init__(self, name): + self.name = name + def __repr__(self): + return self.__class__.__name__ + " " + self.name + + class Manager(Employee): + def __init__(self, name, manager_data): + self.name = name + self.manager_data = manager_data + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " " + self.manager_data + + class Engineer(Employee): + def __init__(self, name, engineer_info): + self.name = name + self.engineer_info = engineer_info + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " " + self.engineer_info + + class JuniorEngineer(Engineer): + pass + + employee_mapper = mapper(Employee, employees_table, polymorphic_on=employees_table.c.type) + manager_mapper = mapper(Manager, inherits=employee_mapper, polymorphic_identity='manager') + engineer_mapper = mapper(Engineer, inherits=employee_mapper, polymorphic_identity='engineer') + junior_engineer = mapper(JuniorEngineer, inherits=engineer_mapper, polymorphic_identity='juniorengineer') + + session = create_session() + + m1 = Manager('Tom', 'knows how to manage things') + e1 = Engineer('Kurt', 'knows how to hack') + e2 = JuniorEngineer('Ed', 'oh that ed') + session.save(m1) + session.save(e1) + session.save(e2) + session.flush() + + assert session.query(Employee).select() == [m1, e1, e2] + assert session.query(Engineer).select() == [e1, e2] + assert session.query(Manager).select() == [m1] + assert session.query(JuniorEngineer).select() == [e2] + +if __name__ == '__main__': + testbase.main() \ No newline at end of file