]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- improvement to single table inheritance to load full hierarchies beneath
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 3 Nov 2006 19:57:39 +0000 (19:57 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 3 Nov 2006 19:57:39 +0000 (19:57 +0000)
the target class

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
test/orm/alltests.py
test/orm/single.py [new file with mode: 0644]

diff --git a/CHANGES b/CHANGES
index f8ba46f0f8e744f5fde27c6b1ca94909503ae223..058804bcfe2d26fe7149da955d8755e02ec3cf26 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -24,6 +24,8 @@ that the class of object attached to a parent object is appropriate
 (i.e. if A.items stores B objects, raise an error if a C is appended to A.items)
 - new extension sqlalchemy.ext.associationproxy, provides transparent "association object"
 mappings.  new example examples/association/proxied_association.py illustrates.
+- improvement to single table inheritance to load full hierarchies beneath 
+the target class
 
 0.3.0
 - General:
index 84a2540b76a5059d5b702c6a4cdcf3eb733169b1..3327f13c34a9cf360f0cd2202deb7bc319e67fb0 100644 (file)
@@ -595,13 +595,17 @@ class Mapper(object):
             m = m.inherits
     
     def polymorphic_iterator(self):
-        m = self.base_mapper()
+        """iterates through the collection including this mapper and all descendant mappers.
+        
+        this includes not just the immediately inheriting mappers but all their inheriting mappers as well.
+        
+        To iterate through an entire hierarchy, use mapper.base_mapper().polymorphic_iterator()."""
         def iterate(m):
             yield m
             for mapper in m._inheriting_mappers:
                 for x in iterate(mapper):
                     yield x
-        return iterate(m)
+        return iterate(self)
                 
     def add_properties(self, dict_of_properties):
         """adds the given dictionary of properties to this mapper, using add_property."""
@@ -831,7 +835,7 @@ class Mapper(object):
         updated_objects = util.Set()
         
         table_to_mapper = {}
-        for mapper in self.polymorphic_iterator():
+        for mapper in self.base_mapper().polymorphic_iterator():
             for t in mapper.tables:
                 table_to_mapper.setdefault(t, mapper)
 
index e257a1cbe140a281816c56ca6eea8bf3ae5758f3..208c7372b870efd630fe21c347b0aa1514557502 100644 (file)
@@ -414,7 +414,7 @@ class Query(object):
             raise exceptions.ArgumentError("Unknown lockmode '%s'" % lockmode)
         
         if self.mapper.single and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None:
-            whereclause = sql.and_(whereclause, self.mapper.polymorphic_on==self.mapper.polymorphic_identity)
+            whereclause = sql.and_(whereclause, self.mapper.polymorphic_on.in_(*[m.polymorphic_identity for m in self.mapper.polymorphic_iterator()]))
         
         alltables = []
         for l in [sql_util.TableFinder(x) for x in from_obj]:
index daa7e38a8f06ba7f4c2f5cf3e685afc5b99c032d..70d0bb6f4c26dc30a7d661e8cdcf093b8cfbd74f 100644 (file)
@@ -27,6 +27,7 @@ def suite():
         'orm.inheritance',
         'orm.inheritance2',
         'orm.inheritance3',
+        'orm.single',
         'orm.polymorph'        
         )
     alltests = unittest.TestSuite()
diff --git a/test/orm/single.py b/test/orm/single.py
new file mode 100644 (file)
index 0000000..d78084f
--- /dev/null
@@ -0,0 +1,63 @@
+from sqlalchemy import *
+import testbase
+
+class SingleInheritanceTest(testbase.AssertMixin):
+    def setUpAll(self):
+        metadata = BoundMetaData(testbase.db)
+        global employees_table
+        employees_table = Table('employees', metadata, 
+            Column('employee_id', Integer, primary_key=True),
+            Column('name', String(50)),
+            Column('manager_data', String(50)),
+            Column('engineer_info', String(50)),
+            Column('type', String(20))
+        )
+        employees_table.create()
+    def tearDownAll(self):
+        employees_table.drop()
+    def testbasic(self):
+        class Employee(object):
+            def __init__(self, name):
+                self.name = name
+            def __repr__(self):
+                return self.__class__.__name__ + " " + self.name
+
+        class Manager(Employee):
+            def __init__(self, name, manager_data):
+                self.name = name
+                self.manager_data = manager_data
+            def __repr__(self):
+                return self.__class__.__name__ + " " + self.name + " " +  self.manager_data
+
+        class Engineer(Employee):
+            def __init__(self, name, engineer_info):
+                self.name = name
+                self.engineer_info = engineer_info
+            def __repr__(self):
+                return self.__class__.__name__ + " " + self.name + " " +  self.engineer_info
+
+        class JuniorEngineer(Engineer):
+            pass
+
+        employee_mapper = mapper(Employee, employees_table, polymorphic_on=employees_table.c.type)
+        manager_mapper = mapper(Manager, inherits=employee_mapper, polymorphic_identity='manager')
+        engineer_mapper = mapper(Engineer, inherits=employee_mapper, polymorphic_identity='engineer')
+        junior_engineer = mapper(JuniorEngineer, inherits=engineer_mapper, polymorphic_identity='juniorengineer')
+
+        session = create_session()
+
+        m1 = Manager('Tom', 'knows how to manage things')
+        e1 = Engineer('Kurt', 'knows how to hack')
+        e2 = JuniorEngineer('Ed', 'oh that ed')
+        session.save(m1)
+        session.save(e1)
+        session.save(e2)
+        session.flush()
+
+        assert session.query(Employee).select() == [m1, e1, e2]
+        assert session.query(Engineer).select() == [e1, e2]
+        assert session.query(Manager).select() == [m1]
+        assert session.query(JuniorEngineer).select() == [e2]
+        
+if __name__ == '__main__':
+    testbase.main()
\ No newline at end of file