]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- repaired single table inheritance such that you rel_0_4_6
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 10 May 2008 17:31:07 +0000 (17:31 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 10 May 2008 17:31:07 +0000 (17:31 +0000)
can single-table inherit from a joined-table inherting
mapper without issue [ticket:1036].

CHANGES
lib/sqlalchemy/orm/mapper.py
test/orm/inheritance/single.py

diff --git a/CHANGES b/CHANGES
index 35d53ab61815ae534f581315052f51dfcf63c0a1..df9120b64b7772559cdbfa7677ea90e3833b57bd 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -28,6 +28,10 @@ CHANGES
       functions and other expressions.  (partial progress
       towards [ticket:610])
     
+    - repaired single table inheritance such that you 
+      can single-table inherit from a joined-table inherting
+      mapper without issue [ticket:1036].
+      
     - Fixed "concatenate tuple" bug which could occur with
       Query.order_by() if clause adaption had taken place.
       [ticket:1027]
index ba0644758f9aae6189137dba57599ae5df4b1a92..448d797d06440bc8c2831703760c908253b31c85 100644 (file)
@@ -420,8 +420,9 @@ class Mapper(object):
             # inherit_condition is optional.
             if self.local_table is None:
                 self.local_table = self.inherits.local_table
+                self.mapped_table = self.inherits.mapped_table
                 self.single = True
-            if not self.local_table is self.inherits.local_table:
+            elif not self.local_table is self.inherits.local_table:
                 if self.concrete:
                     self.mapped_table = self.local_table
                     for mapper in self.iterate_to_root():
@@ -1592,7 +1593,8 @@ class Mapper(object):
         for mapper in self.iterate_to_root():
             if mapper is base_mapper:
                 break
-            allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary))
+            if not mapper.single:
+                allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary))
 
         return sql.and_(*allconds), param_names
 
index 81223cc02e36445b757e61ed953706a01c3169d4..2241afb0f8d9cbd6781a7e045fd7b33342b94de1 100644 (file)
@@ -2,11 +2,10 @@ import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from testlib import *
+from testlib.fixtures import Base
 
-
-class SingleInheritanceTest(TestBase, AssertsExecutionResults):
-    def setUpAll(self):
-        metadata = MetaData(testing.db)
+class SingleInheritanceTest(ORMTest):
+    def define_tables(self, metadata):
         global employees_table
         employees_table = Table('employees', metadata,
             Column('employee_id', Integer, primary_key=True),
@@ -15,43 +14,27 @@ class SingleInheritanceTest(TestBase, AssertsExecutionResults):
             Column('engineer_info', String(50)),
             Column('type', String(20))
         )
-        employees_table.create()
-    def tearDownAll(self):
-        employees_table.drop()
-    def testbasic(self):
-        class Employee(object):
-            def __init__(self, name):
-                self.name = name
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name
 
+    def test_single_inheritance(self):
+        class Employee(Base):
+            pass
         class Manager(Employee):
-            def __init__(self, name, manager_data):
-                self.name = name
-                self.manager_data = manager_data
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name + " " +  self.manager_data
-
+            pass
         class Engineer(Employee):
-            def __init__(self, name, engineer_info):
-                self.name = name
-                self.engineer_info = engineer_info
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name + " " +  self.engineer_info
-
+            pass
         class JuniorEngineer(Engineer):
             pass
 
-        employee_mapper = mapper(Employee, employees_table, polymorphic_on=employees_table.c.type)
-        manager_mapper = mapper(Manager, inherits=employee_mapper, polymorphic_identity='manager')
-        engineer_mapper = mapper(Engineer, inherits=employee_mapper, polymorphic_identity='engineer')
-        junior_engineer = mapper(JuniorEngineer, inherits=engineer_mapper, polymorphic_identity='juniorengineer')
+        mapper(Employee, employees_table, polymorphic_on=employees_table.c.type)
+        mapper(Manager, inherits=Employee, polymorphic_identity='manager')
+        mapper(Engineer, inherits=Employee, polymorphic_identity='engineer')
+        mapper(JuniorEngineer, inherits=Engineer, polymorphic_identity='juniorengineer')
 
         session = create_session()
 
-        m1 = Manager('Tom', 'knows how to manage things')
-        e1 = Engineer('Kurt', 'knows how to hack')
-        e2 = JuniorEngineer('Ed', 'oh that ed')
+        m1 = Manager(name='Tom', manager_data='knows how to manage things')
+        e1 = Engineer(name='Kurt', engineer_info='knows how to hack')
+        e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed')
         session.save(m1)
         session.save(e1)
         session.save(e2)
@@ -62,5 +45,66 @@ class SingleInheritanceTest(TestBase, AssertsExecutionResults):
         assert session.query(Manager).all() == [m1]
         assert session.query(JuniorEngineer).all() == [e2]
 
+class SingleOnJoinedTest(ORMTest):
+    def define_tables(self, metadata):
+        global persons_table, employees_table
+        
+        persons_table = Table('persons', metadata,
+           Column('person_id', Integer, primary_key=True),
+           Column('name', String(50)),
+           Column('type', String(20), nullable=False)
+        )
+
+        employees_table = Table('employees', metadata,
+           Column('person_id', Integer, ForeignKey('persons.person_id'),primary_key=True),
+           Column('employee_data', String(50)),
+           Column('manager_data', String(50)),
+        )
+    
+    def test_single_on_joined(self):
+        class Person(Base):
+            pass
+        class Employee(Person):
+            pass
+        class Manager(Employee):
+            pass
+        
+        mapper(Person, persons_table, polymorphic_on=persons_table.c.type, polymorphic_identity='person')
+        mapper(Employee, employees_table, inherits=Person,polymorphic_identity='engineer')
+        mapper(Manager, inherits=Employee,polymorphic_identity='manager')
+        
+        sess = create_session()
+        sess.save(Person(name='p1'))
+        sess.save(Employee(name='e1', employee_data='ed1'))
+        sess.save(Manager(name='m1', employee_data='ed2', manager_data='md1'))
+        sess.flush()
+        sess.clear()
+        
+        self.assertEquals(sess.query(Person).order_by(Person.person_id).all(), [
+            Person(name='p1'),
+            Employee(name='e1', employee_data='ed1'),
+            Manager(name='m1', employee_data='ed2', manager_data='md1')
+        ])
+        sess.clear()
+
+        self.assertEquals(sess.query(Employee).order_by(Person.person_id).all(), [
+            Employee(name='e1', employee_data='ed1'),
+            Manager(name='m1', employee_data='ed2', manager_data='md1')
+        ])
+        sess.clear()
+
+        self.assertEquals(sess.query(Manager).order_by(Person.person_id).all(), [
+            Manager(name='m1', employee_data='ed2', manager_data='md1')
+        ])
+        sess.clear()
+        
+        def go():
+            self.assertEquals(sess.query(Person).with_polymorphic('*').order_by(Person.person_id).all(), [
+                Person(name='p1'),
+                Employee(name='e1', employee_data='ed1'),
+                Manager(name='m1', employee_data='ed2', manager_data='md1')
+            ])
+        self.assert_sql_count(testing.db, go, 1)
+    
 if __name__ == '__main__':
     testenv.main()