]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added a third layer of inheritance to polymorph test
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 8 Jun 2007 02:12:36 +0000 (02:12 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 8 Jun 2007 02:12:36 +0000 (02:12 +0000)
- added some extra logic to mapper to try to convert a "foreign key" holding PK-col
into a non-FK col
- apparently, polymorphic loading can now be achieved with LEFT OUTER JOINs quite easily (i.e. no UNIONs).
this needs to be studied further (i.e. why was I making everyone use UNION ALL all this time)

lib/sqlalchemy/orm/mapper.py
test/orm/inheritance/polymorph.py

index 525028b0198c46d10da93be4468d801379f54f19..7363d9b57e117027f30614f38c8c252fa01a061c 100644 (file)
@@ -545,6 +545,27 @@ class Mapper(object):
                         break
                 else:
                     raise exceptions.ArgumentError("Cant resolve column " + str(col))
+
+            # this step attempts to resolve the column to an equivalent which is not
+            # a foreign key elsewhere.  this helps with joined table inheritance
+            # so that PKs are expressed in terms of the base table which is always
+            # present in the initial select
+            # TODO: this is a little hacky right now, the "tried" list is to prevent
+            # endless loops between cyclical FKs, try to make this cleaner/work better/etc.,
+            # perhaps via topological sort (pick the leftmost item)
+            tried = util.Set()
+            while True:
+                if not len(c.foreign_keys) or c in tried:
+                    break
+                for cc in c.foreign_keys:
+                    cc = cc.column
+                    c2 = self.mapped_table.corresponding_column(cc, raiseerr=False)
+                    if c2 is not None:
+                        c = c2
+                        tried.add(c)
+                        break
+                else:
+                    break
             primary_key.add(c)
                 
         if len(primary_key) == 0:
index 7b14378bc57109daf14563fa7db1c8781d31be14..b166cc19d67502cf57f5ae372a4c0d12caa92f13 100644 (file)
@@ -23,6 +23,10 @@ class Engineer(Person):
 class Manager(Person):
     def __repr__(self):
         return "Manager %s, status %s, manager_name %s" % (self.get_name(), self.status, self.manager_name)
+class Boss(Manager):
+    def __repr__(self):
+        return "Boss %s, status %s, manager_name %s golf swing %s" % (self.get_name(), self.status, self.manager_name, self.golf_swing)
+        
 class Company(object):
     def __init__(self, **kwargs):
         for key, value in kwargs.iteritems():
@@ -32,7 +36,7 @@ class Company(object):
 
 class PolymorphTest(testbase.ORMTest):
     def define_tables(self, metadata):
-        global companies, people, engineers, managers
+        global companies, people, engineers, managers, boss
         
         # a table to store companies
         companies = Table('companies', metadata, 
@@ -60,6 +64,11 @@ class PolymorphTest(testbase.ORMTest):
            Column('manager_name', String(50))
            )
 
+        boss = Table('boss', metadata, 
+            Column('boss_id', Integer, ForeignKey('managers.person_id'), primary_key=True),
+            Column('golf_swing', String(30)),
+            )
+            
         metadata.create_all()
 
 class CompileTest(PolymorphTest):
@@ -197,7 +206,7 @@ class RelationToSubclassTest(PolymorphTest):
 class RoundTripTest(PolymorphTest):
     pass
           
-def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_colprop=False, use_literal_join=False, polymorphic_fetch=None):
+def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_colprop=False, use_literal_join=False, polymorphic_fetch=None, use_outer_joins=False):
     """generates a round trip test.
     
     include_base - whether or not to include the base 'person' type in the union.
@@ -209,19 +218,31 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
         # create a union that represents both types of joins.  
         if not polymorphic_fetch == 'union':
             person_join = None
+            manager_join = None
         elif include_base:
-            person_join = polymorphic_union(
-                {
-                    'engineer':people.join(engineers),
-                    'manager':people.join(managers),
-                    'person':people.select(people.c.type=='person'),
-                }, None, 'pjoin')
+            if use_outer_joins:
+                person_join = people.outerjoin(engineers).outerjoin(managers).outerjoin(boss)
+                manager_join = people.join(managers).outerjoin(boss)
+            else:
+                person_join = polymorphic_union(
+                    {
+                        'engineer':people.join(engineers),
+                        'manager':people.join(managers),
+                        'person':people.select(people.c.type=='person'),
+                    }, None, 'pjoin')
+                
+                manager_join = people.join(managers).outerjoin(boss)
         else:
-            person_join = polymorphic_union(
-                {
-                    'engineer':people.join(engineers),
-                    'manager':people.join(managers),
-                }, None, 'pjoin')
+            if use_outer_joins:
+                person_join = people.outerjoin(engineers).outerjoin(managers).outerjoin(boss)
+                manager_join = people.join(managers).outerjoin(boss)
+            else:
+                person_join = polymorphic_union(
+                    {
+                        'engineer':people.join(engineers),
+                        'manager':people.join(managers),
+                    }, None, 'pjoin')
+                manager_join = people.join(managers).outerjoin(boss)
 
         if redefine_colprop:
             person_mapper = mapper(Person, people, select_table=person_join, polymorphic_fetch=polymorphic_fetch, polymorphic_on=people.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name})
@@ -229,8 +250,10 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
             person_mapper = mapper(Person, people, select_table=person_join, polymorphic_fetch=polymorphic_fetch, polymorphic_on=people.c.type, polymorphic_identity='person')
         
         mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer')
-        mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager')
+        mapper(Manager, managers, inherits=person_mapper, select_table=manager_join, polymorphic_identity='manager')
 
+        mapper(Boss, boss, inherits=Manager, polymorphic_identity='boss')
+        
         if use_literal_join:
             mapper(Company, companies, properties={
                 'employees': relation(Person, lazy=lazy_relation, primaryjoin=people.c.company_id==companies.c.company_id, private=True, 
@@ -295,18 +318,28 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
         session.flush()
         session.clear()
 
+        # save/load some managers/bosses
+        b = Boss(status='BBB', manager_name='boss', golf_swing='fore', **{person_attribute_name:'daboss'})
+        session.save(b)
+        session.flush()
+        session.clear()
+        c = session.query(Manager).all()
+        assert sets.Set([repr(x) for x in c]) == sets.Set(["Manager pointy haired boss, status AAB, manager_name manager1", "Manager jsmith, status ABA, manager_name manager2", "Boss daboss, status BBB, manager_name boss golf swing fore"]), repr([repr(x) for x in c])
+        
         c = session.query(Company).get(id)
         for e in c.employees:
             print e, e._instance_key
 
         session.delete(c)
         session.flush()
-
-    test_roundtrip.__name__ = "test_%s%s%s%s" % (
+        
+        
+    test_roundtrip.__name__ = "test_%s%s%s%s%s" % (
         (lazy_relation and "lazy" or "eager"),
         (include_base and "_inclbase" or ""),
         (redefine_colprop and "_redefcol" or ""),
-        (polymorphic_fetch != 'union' and '_' + polymorphic_fetch or (use_literal_join and "_litjoin" or ""))
+        (polymorphic_fetch != 'union' and '_' + polymorphic_fetch or (use_literal_join and "_litjoin" or "")),
+        (use_outer_joins and '_outerjoins' or '')
     )
     setattr(RoundTripTest, test_roundtrip.__name__, test_roundtrip)
 
@@ -315,8 +348,12 @@ for include_base in [True, False]:
         for redefine_colprop in [True, False]:
             for use_literal_join in [True, False]:
                 for polymorphic_fetch in ['union', 'select', 'deferred']:
-                    generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join, polymorphic_fetch)
-                
+                    if polymorphic_fetch == 'union':
+                        for use_outer_joins in [True, False]:
+                            generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join, polymorphic_fetch, use_outer_joins)
+                    else:
+                        generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join, polymorphic_fetch, False)
+                        
 if __name__ == "__main__":    
     testbase.main()