]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added new 'polymorphic' example. still trying to understand it :) .
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 6 Mar 2006 02:27:13 +0000 (02:27 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 6 Mar 2006 02:27:13 +0000 (02:27 +0000)
fixes to relation to enable it to locate "direction" more consistently with inheritance relationships
more tweaks to parenthesizing subqueries, unions, etc.

examples/polymorph/polymorph.py [new file with mode: 0644]
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/mapping/mapper.py
lib/sqlalchemy/mapping/properties.py
lib/sqlalchemy/sql.py

diff --git a/examples/polymorph/polymorph.py b/examples/polymorph/polymorph.py
new file mode 100644 (file)
index 0000000..b190380
--- /dev/null
@@ -0,0 +1,131 @@
+from sqlalchemy import *
+import sys
+
+# this example illustrates how to create a relationship to a list of objects,
+# where each object in the list has a different type.  The typed objects will
+# extend from a common base class, although this same approach can be used
+# with 
+
+db = create_engine('sqlite://', echo=True, echo_uow=False)
+
+# a table to store companies
+companies = Table('companies', db, 
+   Column('company_id', Integer, primary_key=True),
+   Column('name', String(50))).create()
+
+# we will define an inheritance relationship between the table "people" and "engineers",
+# and a second inheritance relationship between the table "people" and "managers"
+people = Table('people', db, 
+   Column('person_id', Integer, primary_key=True),
+   Column('company_id', Integer, ForeignKey('companies.company_id')),
+   Column('name', String(50))).create()
+   
+engineers = Table('engineers', db, 
+   Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
+   Column('description', String(50))).create()
+   
+managers = Table('managers', db, 
+   Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
+   Column('description', String(50))).create()
+
+  
+# create our classes.  The Engineer and Manager classes extend from Person.
+class Person(object):
+    def __repr__(self):
+        return "Ordinary person %s" % self.name
+class Engineer(Person):
+    def __repr__(self):
+        return "Engineer %s, description %s" % (self.name, self.description)
+class Manager(Person):
+    def __repr__(self):
+        return "Manager %s, description %s" % (self.name, self.description)
+class Company(object):
+    def __repr__(self):
+        return "Company %s" % self.name
+
+# next we assign Person mappers.  Since these are the first mappers we are
+# creating for these classes, they automatically become the "primary mappers", which
+# define the dependency relationships between the classes, so we do a straight
+# inheritance setup, i.e. no modifications to how objects are loaded or anything like that.
+assign_mapper(Person, people)
+assign_mapper(Engineer, engineers, inherits=Person.mapper)
+assign_mapper(Manager, managers, inherits=Person.mapper)
+
+# next, we define a query that is going to load Managers and Engineers in one shot.
+# this query is tricky since the managers and engineers tables contain the same "description" column,
+# so we set up a full blown select() statement that uses aliases for the description
+# column.  The select() statement is also given an alias 'pjoin', since the mapper requires
+# that all Selectables have a name.  
+#
+# TECHNIQUE - when you want to load a certain set of objects from a in one query, all the
+# columns in the Selectable must have unique names.  Dont worry about mappers at this point,
+# just worry about making a query where if you were to view the results, you could tell everything
+# you need to know from each row how to construct an object instance from it.  this is the
+# essence of "resultset-based-mapping", which is the core ideology of SQLAlchemy.
+#
+person_join = select(
+                [people, managers.c.description,column("'manager'").label('type')], 
+                people.c.person_id==managers.c.person_id).union(
+            select(
+            [people, engineers.c.description, column("'engineer'").label('type')],
+            people.c.person_id==engineers.c.person_id)).alias('pjoin')
+
+
+# lets print out what this Selectable looks like.  The mapper is going to take the selectable and
+# Select off of it, with the flag "use_labels" which indicates to prefix column names with the table
+# name.  So here is what our mapper will see:
+print "Person selectable:", str(person_join.select(use_labels=True)), "\n"
+
+
+# MapperExtension object.
+class PersonLoader(MapperExtension):
+    def create_instance(self, mapper, row, imap, class_):
+        if row['pjoin_type'] =='engineer':
+            return Engineer()
+        elif row['pjoin_type'] =='manager':
+            return Manager()
+        else:
+            return Person()
+ext = PersonLoader()
+
+# set up the polymorphic mapper, which maps the person_join we set up to
+# the Person class, using an instance of PersonLoader.  Note that even though 
+# this mapper is against Person, its not going to screw up the normal operation 
+# of the Person object since its not the "primary" mapper.  In reality, we could even 
+# make this mapper against some other class we dont care about since the creation of
+# objects is hardcoded.
+people_mapper = mapper(Person, person_join, extension=ext)
+
+assign_mapper(Company, companies, properties={
+    'employees': relation(people_mapper),
+    'engineers': relation(Engineer, private=True),
+    'managers':relation(Manager, private=True)
+})
+
+
+c = Company(name='company1')
+c.employees.append(Manager(name='pointy haired boss', description='manager1'))
+c.employees.append(Engineer(name='dilbert', description='engineer1'))
+c.employees.append(Engineer(name='wally', description='engineer2'))
+c.employees.append(Manager(name='jsmith', description='manager2'))
+objectstore.commit()
+
+objectstore.clear()
+
+c = Company.get(1)
+for e in c.employees:
+    print e, e._instance_key
+
+print "\n"
+
+dilbert = c.employees[1]
+dilbert.description = 'hes dibert!'
+objectstore.commit()
+
+objectstore.clear()
+c = Company.get(1)
+for e in c.employees:
+    print e, e._instance_key
+
+objectstore.delete(c)
+objectstore.commit()
\ No newline at end of file
index 64715cb4f16d4a96b403b82664d9632806261d1f..118f9f8094359dc18420538edf77c88f422b45f5 100644 (file)
@@ -240,7 +240,10 @@ class ANSICompiler(sql.Compiled):
         text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ")
         for tup in cs.clauses:
             text += " " + tup[0] + " " + self.get_str(tup[1])
-        self.strings[cs] = text
+        if cs.parens:
+            self.strings[cs] = "(" + text + ")"
+        else:
+            self.strings[cs] = text
         self.froms[cs] = "(" + text + ")"
 
     def visit_binary(self, binary):
@@ -368,7 +371,7 @@ class ANSICompiler(sql.Compiled):
                 
         text += self.visit_select_postclauses(select)
  
-        if getattr(select, 'useparens', False):
+        if getattr(select, 'parens', False):
             self.strings[select] = "(" + text + ")"
         else:
             self.strings[select] = text
index 6a7eb9659db804f27ceb675a9f55d4f3555d595c..7b4595fb05303b42ed9eb2ffb66409c86c48cfdd 100644 (file)
@@ -952,7 +952,6 @@ def object_mapper(object):
 
 def class_mapper(class_):
     """given a class, returns the primary Mapper associated with the class."""
-    return mapper_registry[class_]
     try:
         return mapper_registry[class_]
     except KeyError:
index 023e44bf7b95041ff4a302a0d982d41bc1ba5f79..0f83568eddb892e6252ca1c2df649cf19f4b4830 100644 (file)
@@ -229,7 +229,8 @@ class PropertyLoader(MapperProperty):
         
     def _get_direction(self):
         """determines our 'direction', i.e. do we represent one to many, many to many, etc."""
-#        print self.key, repr(self.parent.table.name), repr(self.parent.primarytable.name), repr(self.foreignkey.table.name)
+        #print self.key, repr(self.parent.table.name), repr(self.parent.primarytable.name), repr(self.foreignkey.table.name), repr(self.target), repr(self.foreigntable.name)
+        
         if self.parent.table is self.target:
             if self.foreignkey.primary_key:
                 return PropertyLoader.MANYTOONE
@@ -237,9 +238,9 @@ class PropertyLoader(MapperProperty):
                 return PropertyLoader.ONETOMANY
         elif self.secondaryjoin is not None:
             return PropertyLoader.MANYTOMANY
-        elif self.foreigntable == self.target:
+        elif self.foreigntable is self.target or self.foreigntable in self.mapper.tables:
             return PropertyLoader.ONETOMANY
-        elif self.foreigntable == self.parent.table:
+        elif self.foreigntable is self.parent.table or self.foreigntable in self.parent.tables:
             return PropertyLoader.MANYTOONE
         else:
             raise ArgumentError("Cant determine relation direction")
index d4d059d6a730e4ad806f9e345153c55ff0b1b4d9..a6ddf8cb97545fd695fd0ef20a5d75d98e4833d8 100644 (file)
@@ -454,7 +454,7 @@ class CompareMixin(object):
             # assume *other is a list of selects.
             # so put them in a UNION.  if theres only one, you just get one SELECT 
             # statement out of it.
-            return self._compare('IN', union(*other))
+            return self._compare('IN', union(parens=True, *other))
     def startswith(self, other):
         return self._compare('LIKE', str(other) + "%")
     def endswith(self, other):
@@ -1123,6 +1123,7 @@ class CompoundSelect(SelectBaseMixin, FromClause):
         self.keyword = keyword
         self.selects = selects
         self.use_labels = kwargs.pop('use_labels', False)
+        self.parens = kwargs.pop('parens', False)
         self.oid_column = selects[0].oid_column
         for s in self.selects:
             s.group_by(None)
@@ -1209,7 +1210,8 @@ class Select(SelectBaseMixin, FromClause):
         def visit_compound_select(self, cs):
             self.visit_select(cs)
             for s in cs.selects:
-                s.useparens = False
+                s.parens = False
+            print "BUT", id(cs), cs.parens
         def visit_column(self, c):pass
         def visit_table(self, c):pass
         def visit_select(self, select):
@@ -1217,7 +1219,7 @@ class Select(SelectBaseMixin, FromClause):
                 return
             select.is_where = self.is_where
             select.issubquery = True
-            select.useparens = True
+            select.parens = True
             if getattr(select, '_correlated', None) is None:
                 select._correlated = self.select._froms