]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- joined-table inheritance will now generate the primary key
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 23 Jul 2007 22:20:44 +0000 (22:20 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 23 Jul 2007 22:20:44 +0000 (22:20 +0000)
columns of all inherited classes against the root table of the
join only.  This implies that each row in the root table is distinct
to a single instance.  If for some rare reason this is not desireable,
explicit primary_key settings on individual mappers will override it.

- When "polymorphic" flags are used with joined-table or single-table
inheritance, all identity keys are generated against the root class
of the inheritance hierarchy; this allows query.get() to work
polymorphically using the same caching semantics as a non-polymorphic get.
note that this currently does not work with concrete inheritance.

CHANGES
examples/polymorph/concrete.py
examples/polymorph/single.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/sql_util.py
test/orm/inheritance/basic.py
test/orm/inheritance/concrete.py

diff --git a/CHANGES b/CHANGES
index 4834e366269b62c1e02101767ea47675c617aa8c..092ee176ec3651a1fc4cde81725cff4dc5087558 100644 (file)
--- a/CHANGES
+++ b/CHANGES
       should "collapse" into a single-valued (or fewer-valued) primary key.
       fixes things like [ticket:611].
 
+    - joined-table inheritance will now generate the primary key
+      columns of all inherited classes against the root table of the 
+      join only.  This implies that each row in the root table is distinct
+      to a single instance.  If for some rare reason this is not desireable,
+      explicit primary_key settings on individual mappers will override it.
+      
+    - When "polymorphic" flags are used with joined-table or single-table
+      inheritance, all identity keys are generated against the root class 
+      of the inheritance hierarchy; this allows query.get() to work 
+      polymorphically using the same caching semantics as a non-polymorphic get.
+      note that this currently does not work with concrete inheritance.
+      
     - secondary inheritance loading: polymorphic mappers can be
       constructed *without* a select_table argument. inheriting mappers
       whose tables were not represented in the initial load will issue a
index 593d3f4805cac4e0d6342d80e57bd2114abac3b6..5f12e9a3d73c67467de9d51045bd387f5e75c1f6 100644 (file)
@@ -1,4 +1,5 @@
 from sqlalchemy import *
+from sqlalchemy.orm import *
 
 metadata = MetaData()
 
@@ -49,7 +50,7 @@ manager_mapper = mapper(Manager, managers_table, inherits=employee_mapper, concr
 engineer_mapper = mapper(Engineer, engineers_table, inherits=employee_mapper, concrete=True, polymorphic_identity='engineer')
 
 
-session = create_session(bind_to=engine)
+session = create_session(bind=engine)
 
 m1 = Manager("pointy haired boss", "manager1")
 e1 = Engineer("wally", "engineer1")
index dcdb3c8906bd1172d6fd1befcb075064a4edddd2..61809a05c15d6f71578c0c044ef0d287069a5c50 100644 (file)
@@ -1,4 +1,5 @@
 from sqlalchemy import *
+from sqlalchemy.orm import *
 
 metadata = MetaData('sqlite://', echo='debug')
 
index 097c906abfc1b4e1b7f0e2b2344b0520180946d9..555c1990e55fcb54519616b6ea3a3e55e22d9030 100644 (file)
@@ -469,8 +469,17 @@ class Mapper(object):
                 self.mapped_table = self.local_table
             if self.polymorphic_identity is not None:
                 self.inherits._add_polymorphic_mapping(self.polymorphic_identity, self)
-            if self.polymorphic_on is None and self.inherits.polymorphic_on is not None:
-                self.polymorphic_on = self.mapped_table.corresponding_column(self.inherits.polymorphic_on, keys_ok=True, raiseerr=False)
+                if self.polymorphic_on is None:
+                    if self.inherits.polymorphic_on is not None:
+                        self.polymorphic_on = self.mapped_table.corresponding_column(self.inherits.polymorphic_on, keys_ok=True, raiseerr=False)
+                    else:
+                        raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity))
+
+            if self.polymorphic_identity is not None and not self.concrete:
+                self._identity_class = self.inherits._identity_class
+            else:
+                self._identity_class = self.class_
+                
             if self.order_by is False:
                 self.order_by = self.inherits.order_by
             self.polymorphic_map = self.inherits.polymorphic_map
@@ -480,8 +489,11 @@ class Mapper(object):
             self._synchronizer = None
             self.mapped_table = self.local_table
             if self.polymorphic_identity is not None:
+                if self.polymorphic_on is None:
+                    raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity))
                 self._add_polymorphic_mapping(self.polymorphic_identity, self)
-
+            self._identity_class = self.class_
+            
         if self.mapped_table is None:
             raise exceptions.ArgumentError("Mapper '%s' does not have a mapped_table specified.  (Are you using the return value of table.create()?  It no longer has a return value.)" % str(self))
 
@@ -540,58 +552,60 @@ class Mapper(object):
         if len(self.pks_by_table[self.mapped_table]) == 0:
             raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name))
 
-        # create the "primary_key" for this mapper.  this will flatten "equivalent" primary key columns
-        # into one column, where "equivalent" means that one column references the other via foreign key, or
-        # multiple columns that all reference a common parent column.  it will also resolve the column
-        # against the "mapped_table" of this mapper.
-        equivalent_columns = self._get_equivalent_columns()
+        if self.inherits is not None and not self.concrete and not self.primary_key_argument:
+            self.primary_key = self.inherits.primary_key
+            self._get_clause = self.inherits._get_clause
+        else:
+            # create the "primary_key" for this mapper.  this will flatten "equivalent" primary key columns
+            # into one column, where "equivalent" means that one column references the other via foreign key, or
+            # multiple columns that all reference a common parent column.  it will also resolve the column
+            # against the "mapped_table" of this mapper.
+            equivalent_columns = self._get_equivalent_columns()
         
-        primary_key = sql.ColumnSet()
-
-        for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]):
-            #primary_key.add(col)
-            #continue
-            c = self.mapped_table.corresponding_column(col, raiseerr=False)
-            if c is None:
-                for cc in equivalent_columns[col]:
-                    c = self.mapped_table.corresponding_column(cc, raiseerr=False)
-                    if c is not None:
+            primary_key = sql.ColumnSet()
+
+            for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]):
+                c = self.mapped_table.corresponding_column(col, raiseerr=False)
+                if c is None:
+                    for cc in equivalent_columns[col]:
+                        c = self.mapped_table.corresponding_column(cc, raiseerr=False)
+                        if c is not None:
+                            break
+                    else:
+                        raise exceptions.ArgumentError("Cant resolve column " + str(col))
+
+                # this step attempts to resolve the column to an equivalent which is not
+                # a foreign key elsewhere.  this helps with joined table inheritance
+                # so that PKs are expressed in terms of the base table which is always
+                # present in the initial select
+                # TODO: this is a little hacky right now, the "tried" list is to prevent
+                # endless loops between cyclical FKs, try to make this cleaner/work better/etc.,
+                # perhaps via topological sort (pick the leftmost item)
+                tried = util.Set()
+                while True:
+                    if not len(c.foreign_keys) or c in tried:
                         break
-                else:
-                    raise exceptions.ArgumentError("Cant resolve column " + str(col))
-
-            # this step attempts to resolve the column to an equivalent which is not
-            # a foreign key elsewhere.  this helps with joined table inheritance
-            # so that PKs are expressed in terms of the base table which is always
-            # present in the initial select
-            # TODO: this is a little hacky right now, the "tried" list is to prevent
-            # endless loops between cyclical FKs, try to make this cleaner/work better/etc.,
-            # perhaps via topological sort (pick the leftmost item)
-            tried = util.Set()
-            while True:
-                if not len(c.foreign_keys) or c in tried:
-                    break
-                for cc in c.foreign_keys:
-                    cc = cc.column
-                    c2 = self.mapped_table.corresponding_column(cc, raiseerr=False)
-                    if c2 is not None:
-                        c = c2
-                        tried.add(c)
+                    for cc in c.foreign_keys:
+                        cc = cc.column
+                        c2 = self.mapped_table.corresponding_column(cc, raiseerr=False)
+                        if c2 is not None:
+                            c = c2
+                            tried.add(c)
+                            break
+                    else:
                         break
-                else:
-                    break
-            primary_key.add(c)
+                primary_key.add(c)
                 
-        if len(primary_key) == 0:
-            raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name))
+            if len(primary_key) == 0:
+                raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name))
 
-        self.primary_key = primary_key
-        self.__log("Identified primary key columns: " + str(primary_key))
+            self.primary_key = primary_key
+            self.__log("Identified primary key columns: " + str(primary_key))
         
-        _get_clause = sql.and_()
-        for primary_key in self.primary_key:
-            _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type_=primary_key.type, unique=True))
-        self._get_clause = _get_clause
+            _get_clause = sql.and_()
+            for primary_key in self.primary_key:
+                _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type_=primary_key.type, unique=True))
+            self._get_clause = _get_clause
 
     def _get_equivalent_columns(self):
         """Create a map of all *equivalent* columns, based on
@@ -996,7 +1010,7 @@ class Mapper(object):
           dictionary corresponding result-set ``ColumnElement``
           instances to their values within a row.
         """
-        return (self.class_, tuple([row[column] for column in self.primary_key]), self.entity_name)
+        return (self._identity_class, tuple([row[column] for column in self.primary_key]), self.entity_name)
 
     def identity_key_from_primary_key(self, primary_key):
         """Return an identity-map key for use in storing/retrieving an
@@ -1005,7 +1019,7 @@ class Mapper(object):
         primary_key
           A list of values indicating the identifier.
         """
-        return (self.class_, tuple(util.to_list(primary_key)), self.entity_name)
+        return (self._identity_class, tuple(util.to_list(primary_key)), self.entity_name)
 
     def identity_key_from_instance(self, instance):
         """Return the identity key for the given instance, based on
index f31d0013c2630cf252383760f62ee458858f1fa6..d91fbe4b522b8fafc8f60cd8a65134d8d1e25d23 100644 (file)
@@ -147,7 +147,6 @@ class AbstractClauseProcessor(sql.NoColumnVisitor):
     def visit_clauselist(self, clist):
         for i in range(0, len(clist.clauses)):
             n = self.convert_element(clist.clauses[i])
-            print "CONVERTEING CLAUSELIST W ID", id(clist)
             if n is not None:
                 clist.clauses[i] = n
     
index c6cd43f439d46540491bf76f32f38f0a2cca8430..be623e1b876b8379988266526f4a8631fda15bd0 100644 (file)
@@ -62,9 +62,93 @@ class O2MTest(ORMTest):
         self.assert_(compare == result)
         self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1')
 
-class AddPropTest(ORMTest):
-    """testing that construction of inheriting mappers works regardless of when extra properties
-    are added to the superclass mapper"""
+class GetTest(ORMTest):    
+    def define_tables(self, metadata):
+        global foo, bar, blub
+        foo = Table('foo', metadata,
+            Column('id', Integer, Sequence('foo_seq'), primary_key=True),
+            Column('type', String(30)),
+            Column('data', String(20)))
+
+        bar = Table('bar', metadata,
+            Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
+            Column('data', String(20)))
+
+        blub = Table('blub', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('foo_id', Integer, ForeignKey('foo.id')),
+            Column('bar_id', Integer, ForeignKey('bar.id')),
+            Column('data', String(20)))
+    
+    def create_test(polymorphic):
+        def test_get(self):
+            class Foo(object):
+                pass
+
+            class Bar(Foo):
+                pass
+        
+            class Blub(Bar):
+                pass
+
+            if polymorphic:
+                mapper(Foo, foo, polymorphic_on=foo.c.type, polymorphic_identity='foo')
+                mapper(Bar, bar, inherits=Foo, polymorphic_identity='bar')
+                mapper(Blub, blub, inherits=Bar, polymorphic_identity='blub')
+            else:
+                mapper(Foo, foo)
+                mapper(Bar, bar, inherits=Foo)
+                mapper(Blub, blub, inherits=Bar)
+        
+            sess = create_session()
+            f = Foo()
+            b = Bar()
+            bl = Blub()
+            sess.save(f)
+            sess.save(b)
+            sess.save(bl)
+            sess.flush()
+            
+            if polymorphic:
+                def go():
+                    assert sess.query(Foo).get(f.id) == f
+                    assert sess.query(Foo).get(b.id) == b
+                    assert sess.query(Foo).get(bl.id) == bl
+                    assert sess.query(Bar).get(b.id) == b
+                    assert sess.query(Bar).get(bl.id) == bl
+                    assert sess.query(Blub).get(bl.id) == bl
+            
+                self.assert_sql_count(testbase.db, go, 0)
+            else:
+                # this is testing the 'wrong' behavior of using get() 
+                # polymorphically with mappers that are not configured to be
+                # polymorphic.  the important part being that get() always
+                # returns an instance of the query's type.
+                def go():
+                    assert sess.query(Foo).get(f.id) == f
+                    
+                    bb = sess.query(Foo).get(b.id)
+                    assert isinstance(b, Foo) and bb.id==b.id
+                    
+                    bll = sess.query(Foo).get(bl.id)
+                    assert isinstance(bll, Foo) and bll.id==bl.id
+                    
+                    assert sess.query(Bar).get(b.id) == b
+                    
+                    bll = sess.query(Bar).get(bl.id)
+                    assert isinstance(bll, Bar) and bll.id == bl.id
+                    
+                    assert sess.query(Blub).get(bl.id) == bl
+            
+                self.assert_sql_count(testbase.db, go, 3)
+                
+        return test_get
+        
+    test_get_polymorphic = create_test(True)
+    test_get_nonpolymorphic = create_test(False)
+
+
+class ConstructionTest(ORMTest):
     def define_tables(self, metadata):
         global content_type, content, product
         content_type = Table('content_type', metadata, 
@@ -72,7 +156,8 @@ class AddPropTest(ORMTest):
             )
         content = Table('content', metadata,
             Column('id', Integer, primary_key=True),
-            Column('content_type_id', Integer, ForeignKey('content_type.id'))
+            Column('content_type_id', Integer, ForeignKey('content_type.id')),
+            Column('type', String(30))
             )
         product = Table('product', metadata, 
             Column('id', Integer, ForeignKey('content.id'), primary_key=True)
@@ -86,11 +171,15 @@ class AddPropTest(ORMTest):
         content_types = mapper(ContentType, content_type)
         contents = mapper(Content, content, properties={
             'content_type':relation(content_types)
-        })
-        #contents.add_property('content_type', relation(content_types)) #adding this makes the inheritance stop working
-        # shouldnt throw exception
-        products = mapper(Product, product, inherits=contents)
-        # TODO: assertion ??
+        }, polymorphic_identity='contents')
+
+        products = mapper(Product, product, inherits=contents, polymorphic_identity='products')
+        
+        try:
+            compile_mappers()
+            assert False
+        except exceptions.ArgumentError, e:
+            assert str(e) == "Mapper 'Mapper|Content|content' specifies a polymorphic_identity of 'contents', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument"
 
     def testbackref(self):
         """tests adding a property to the superclass mapper"""
@@ -98,8 +187,8 @@ class AddPropTest(ORMTest):
         class Content(object): pass
         class Product(Content): pass
 
-        contents = mapper(Content, content)
-        products = mapper(Product, product, inherits=contents)
+        contents = mapper(Content, content, polymorphic_on=content.c.type, polymorphic_identity='content')
+        products = mapper(Product, product, inherits=contents, polymorphic_identity='product')
         content_types = mapper(ContentType, content_type, properties={
             'content':relation(contents, backref='contenttype')
         })
@@ -278,12 +367,7 @@ class DistinctPKTest(ORMTest):
     def test_implicit(self):
         person_mapper = mapper(Person, person_table)
         mapper(Employee, employee_table, inherits=person_mapper)
-        try:
-            print class_mapper(Employee).primary_key
-            assert list(class_mapper(Employee).primary_key) == [person_table.c.id, employee_table.c.id]
-            assert False
-        except RuntimeWarning, e:
-            assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'.  Use explicit properties to give each column its own mapped attribute name."
+        assert list(class_mapper(Employee).primary_key) == [person_table.c.id]
 
     def test_explicit_props(self):
         person_mapper = mapper(Person, person_table)
index 167b25256da9bbcb38d06fd9646c0cffca546af6..d95a96da5f3cf2e401597ba8a4c1c48906b001bb 100644 (file)
@@ -54,6 +54,7 @@ class ConcreteTest1(ORMTest):
         session.flush()
         session.clear()
 
+        print set([repr(x) for x in session.query(Employee).select()])
         assert set([repr(x) for x in session.query(Employee).select()]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"])
         assert set([repr(x) for x in session.query(Manager).select()]) == set(["Manager Tom knows how to manage things"])
         assert set([repr(x) for x in session.query(Engineer).select()]) == set(["Engineer Kurt knows how to hack"])