]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
some more tweaks to get more advanced polymorphic stuff to work
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 24 Mar 2006 06:28:27 +0000 (06:28 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 24 Mar 2006 06:28:27 +0000 (06:28 +0000)
examples/polymorph/polymorph.py
examples/polymorph/polymorph2.py [new file with mode: 0644]
lib/sqlalchemy/mapping/mapper.py
lib/sqlalchemy/mapping/properties.py
lib/sqlalchemy/sql.py

index 804ef4e7a30e488fd0f3d83cc92622fac0339826..d105a64ea28f66774d5b03fbe64d946e1197b6c1 100644 (file)
@@ -23,7 +23,7 @@ people = Table('people', db,
    
 engineers = Table('engineers', db, 
    Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
-   Column('description', String(50))).create()
+   Column('special_description', String(50))).create()
    
 managers = Table('managers', db, 
    Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
@@ -36,7 +36,7 @@ class Person(object):
         return "Ordinary person %s" % self.name
 class Engineer(Person):
     def __repr__(self):
-        return "Engineer %s, description %s" % (self.name, self.description)
+        return "Engineer %s, description %s" % (self.name, self.special_description)
 class Manager(Person):
     def __repr__(self):
         return "Manager %s, description %s" % (self.name, self.description)
@@ -69,7 +69,7 @@ person_join = select(
                 [people, managers.c.description,column("'manager'").label('type')], 
                 people.c.person_id==managers.c.person_id).union_all(
             select(
-            [people, engineers.c.description, column("'engineer'").label('type')],
+            [people, engineers.c.special_description.label('description'), column("'engineer'").label('type')],
             people.c.person_id==engineers.c.person_id)).alias('pjoin')
             
 
@@ -83,7 +83,9 @@ print "Person selectable:", str(person_join.select(use_labels=True)), "\n"
 class PersonLoader(MapperExtension):
     def create_instance(self, mapper, row, imap, class_):
         if row['pjoin_type'] =='engineer':
-            return Engineer()
+            e = Engineer()
+            e.special_description = row['pjoin_description']
+            return e
         elif row['pjoin_type'] =='manager':
             return Manager()
         else:
@@ -111,8 +113,8 @@ assign_mapper(Company, companies, properties={
 
 c = Company(name='company1')
 c.employees.append(Manager(name='pointy haired boss', description='manager1'))
-c.employees.append(Engineer(name='dilbert', description='engineer1'))
-c.employees.append(Engineer(name='wally', description='engineer2'))
+c.employees.append(Engineer(name='dilbert', special_description='engineer1'))
+c.employees.append(Engineer(name='wally', special_description='engineer2'))
 c.employees.append(Manager(name='jsmith', description='manager2'))
 objectstore.commit()
 
@@ -125,7 +127,7 @@ for e in c.employees:
 print "\n"
 
 dilbert = Engineer.mapper.get_by(name='dilbert')
-dilbert.description = 'hes dibert!'
+dilbert.special_description = 'hes dibert!'
 objectstore.commit()
 
 objectstore.clear()
diff --git a/examples/polymorph/polymorph2.py b/examples/polymorph/polymorph2.py
new file mode 100644 (file)
index 0000000..99ee6c3
--- /dev/null
@@ -0,0 +1,137 @@
+from sqlalchemy import *
+import sys
+
+# this example illustrates a polymorphic load of two classes, where each class has a very 
+# different set of properties
+
+db = create_engine('sqlite://', echo=True, echo_uow=False)
+
+# a table to store companies
+companies = Table('companies', db, 
+   Column('company_id', Integer, primary_key=True),
+   Column('name', String(50))).create()
+
+# we will define an inheritance relationship between the table "people" and "engineers",
+# and a second inheritance relationship between the table "people" and "managers"
+people = Table('people', db, 
+   Column('person_id', Integer, primary_key=True),
+   Column('company_id', Integer, ForeignKey('companies.company_id')),
+   Column('name', String(50))).create()
+   
+engineers = Table('engineers', db, 
+   Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
+   Column('status', String(30)),
+   Column('engineer_name', String(50)),
+   Column('primary_language', String(50)),
+  ).create()
+   
+managers = Table('managers', db, 
+   Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
+   Column('status', String(30)),
+   Column('manager_name', String(50))
+   ).create()
+
+  
+# create our classes.  The Engineer and Manager classes extend from Person.
+class Person(object):
+    def __repr__(self):
+        return "Ordinary person %s" % self.name
+class Engineer(Person):
+    def __repr__(self):
+        return "Engineer %s, status %s, engineer_name %s, primary_language %s" % (self.name, self.status, self.engineer_name, self.primary_language)
+class Manager(Person):
+    def __repr__(self):
+        return "Manager %s, status %s, manager_name %s" % (self.name, self.status, self.manager_name)
+class Company(object):
+    def __repr__(self):
+        return "Company %s" % self.name
+
+# assign plain vanilla mappers
+assign_mapper(Person, people)
+assign_mapper(Engineer, engineers, inherits=Person.mapper)
+assign_mapper(Manager, managers, inherits=Person.mapper)
+
+# create a union that represents both types of joins.  we have to use
+# nulls to pad out the disparate columns.
+person_join = select(
+                [
+                    people, 
+                    managers.c.status, 
+                    managers.c.manager_name,
+                    null().label('engineer_name'),
+                    null().label('primary_language'),
+                    column("'manager'").label('type')
+                ], 
+                people.c.person_id==managers.c.person_id).union_all(
+            select(
+                [
+                    people, 
+                    engineers.c.status, 
+                    null().label('').label('manager_name'),
+                    engineers.c.engineer_name,
+                    engineers.c.primary_language, 
+                    column("'engineer'").label('type')
+                ],
+            people.c.person_id==engineers.c.person_id)).alias('pjoin')
+            
+    
+# MapperExtension object.
+class PersonLoader(MapperExtension):
+    def create_instance(self, mapper, row, imap, class_):
+        if row[person_join.c.type] =='engineer':
+            return Engineer()
+        elif row[person_join.c.type] =='manager':
+            return Manager()
+        else:
+            return Person()
+            
+    def populate_instance(self, mapper, instance, row, identitykey, imap, isnew):
+        if row[person_join.c.type] =='engineer':
+            Engineer.mapper.populate_instance(instance, row, identitykey, imap, isnew)
+            return False
+        elif row[person_join.c.type] =='manager':
+            Manager.mapper.populate_instance(instance, row, identitykey, imap, isnew)
+            return False
+        else:
+            return True
+            
+        
+
+people_mapper = mapper(Person, person_join, extension=PersonLoader())
+
+assign_mapper(Company, companies, properties={
+    'employees': relation(people_mapper, lazy=False, private=True)
+})
+
+c = Company(name='company1')
+c.employees.append(Manager(name='pointy haired boss', status='AAB', manager_name='manager1'))
+c.employees.append(Engineer(name='dilbert', status='BBA', engineer_name='engineer1', primary_language='java'))
+c.employees.append(Engineer(name='wally', status='CGG', engineer_name='engineer2', primary_language='python'))
+c.employees.append(Manager(name='jsmith', status='ABA', manager_name='manager2'))
+objectstore.commit()
+
+objectstore.clear()
+
+c = Company.get(1)
+for e in c.employees:
+    print e, e._instance_key
+
+print "\n"
+
+dilbert = Engineer.mapper.get_by(name='dilbert')
+dilbert.engineer_name = 'hes dibert!'
+objectstore.commit()
+
+objectstore.clear()
+c = Company.get(1)
+for e in c.employees:
+    print e, e._instance_key
+
+objectstore.delete(c)
+objectstore.commit()
+
+
+managers.drop()
+engineers.drop()
+people.drop()
+companies.drop()
index 5e0f257386761d013ebc1d82e17571a623d78571..5b56358bae7830c44263a4ca8a0499640794fdbf 100644 (file)
@@ -291,7 +291,7 @@ class Mapper(object):
             objectstore.get_session().register_clean(value)
 
         if mappers:
-            result.extend(otherresults)
+            result = [result] + otherresults
         return result
             
     def get(self, *ident):
@@ -837,8 +837,8 @@ class Mapper(object):
         
         # call further mapper properties on the row, to pull further 
         # instances from the row and possibly populate this item.
-        for prop in self.props.values():
-            prop.execute(instance, row, identitykey, imap, isnew)
+        if self.extension.populate_instance(self, instance, row, identitykey, imap, isnew):
+            self.populate_instance(instance, row, identitykey, imap, isnew, translate=False)
 
         if self.extension.append_result(self, row, imap, result, instance, isnew, populate_existing=populate_existing):
             if result is not None:
@@ -846,6 +846,17 @@ class Mapper(object):
 
         return instance
 
+    def populate_instance(self, instance, row, identitykey, imap, isnew, translate=True):
+        if translate:
+            newrow = {}
+            for table in self.tables:
+                for c in table.c:
+                    newrow[c] = row[c.key]
+            row = newrow
+            
+        for prop in self.props.values():
+            prop.execute(instance, row, identitykey, imap, isnew)
+        
 class MapperProperty(object):
     """an element attached to a Mapper that describes and assists in the loading and saving 
     of an attribute on an object instance."""
@@ -930,7 +941,8 @@ class MapperExtension(object):
     def append_result(self, mapper, row, imap, result, instance, isnew, populate_existing=False):
         """called when an object instance is being appended to a result list.
         
-        If it returns True, it is assumed that this method handled the appending itself.
+        If this method returns True, it is assumed that the mapper should do the appending, else
+        if this method returns False, it is assumed that the append was handled by this method.
 
         mapper - the mapper doing the operation
         
@@ -956,6 +968,22 @@ class MapperExtension(object):
             return True
         else:
             return self.next.append_result(mapper, row, imap, result, instance, isnew, populate_existing)
+    def populate_instance(self, mapper, instance, row, identitykey, imap, isnew):
+        """called right before the mapper, after creating an instance from a row, passes the row
+        to its MapperProperty objects which are responsible for populating the object's attributes.
+        If this method returns True, it is assumed that the mapper should do the appending, else
+        if this method returns False, it is assumed that the append was handled by this method.
+        
+        Essentially, this method is used to have a different mapper populate the object:
+        
+            def populate_instance(self, mapper, *args):
+                othermapper.populate_instance(*args)
+                return False
+        """
+        if self.next is None:
+            return True
+        else:
+            return self.next.populate_instance(row, imap, result, instance, isnew)
     def before_insert(self, mapper, instance):
         """called before an object instance is INSERTed into its table.
         
index e5bcee78cb36eec3e6ef5d06a3396550a17a6b83..33405423362623c9b0a2661fb57c77863154907f 100644 (file)
@@ -807,11 +807,14 @@ class EagerLoader(PropertyLoader):
                 if map.has_key(key):
                     key = map[key]
                 return self.row[key]
+            def keys(self):
+                return map.keys()
         map = {}        
         for c in self.eagertarget.c:
             parent = self.target._get_col_by_original(c.original)
             map[parent] = c
             map[parent._label] = c
+            map[parent.name] = c
         return DecoratorDict
         
     def _instance(self, row, imap, result_list=None):
index b945587bd4cf719c8d14f6b61ef68b8f7e8dee8c..d0ab4578aae4ffba17aaad5b4c1b57e5eba1afe3 100644 (file)
@@ -13,7 +13,7 @@ from exceptions import *
 import string, re, random
 types = __import__('types')
 
-__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'union', 'union_all', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists']
+__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists']
 
 def desc(column):
     """returns a descending ORDER BY clause element, e.g.:
@@ -705,9 +705,11 @@ class TextClause(ClauseElement):
     def _get_from_objects(self):
         return []
 
-class Null(ClauseElement):
+class Null(ColumnElement):
     """represents the NULL keyword in a SQL statement. public contstructor is the
     null() function."""
+    def __init__(self):
+        self.type = sqltypes.NULLTYPE
     def accept_visitor(self, visitor):
         visitor.visit_null(self)
     def _get_from_objects(self):