]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
relationships no longer compile against the "selectable" mapper (i.e. the polymorphic...
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 23 Jan 2007 04:05:07 +0000 (04:05 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 23 Jan 2007 04:05:07 +0000 (04:05 +0000)
CHANGES
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql_util.py
test/orm/inheritance5.py
test/orm/poly_linked_list.py
test/orm/polymorph.py

diff --git a/CHANGES b/CHANGES
index 3f44ae04c592c9fa6718b3dfc7c359200802657d..6a38f165693a55c31dcbbe420b605bfcfb0add0c 100644 (file)
--- a/CHANGES
+++ b/CHANGES
   - poked the first hole in the can of worms: saying query.select_by(somerelationname=someinstance)
   will create the join of the primary key columns represented by "somerelationname"'s mapper to the
   actual primary key in "someinstance".
-  - some deeper error checking when compiling relations, to detect an ambiguous "primaryjoin"
-  in the case that both sides of the relationship have foreign key references in the primary
-  join condition.  also tightened down conditions used to locate "relation direction", associating
-  the "foreignkey" of the relationship with the "primaryjoin"
+  - reworked how relations interact with "polymorphic" mappers, i.e. mappers that have a select_table
+  as well as polymorphic flags.  better determination of proper join conditions, interaction with user-
+  defined join conditions, and support for self-referential polymorphic mappers.
+  - related to polymorphic mapping relations, some deeper error checking when compiling relations, 
+  to detect an ambiguous "primaryjoin" in the case that both sides of the relationship have foreign key 
+  references in the primary join condition.  also tightened down conditions used to locate "relation direction", 
+  associating the "foreignkey" of the relationship with the "primaryjoin"
   - a little bit of improvement to the concept of a "concrete" inheritance mapping, though that concept
   is not well fleshed out yet (added test case to support concrete mappers on top of a polymorphic base).
   - fix to "proxy=True" behavior on synonym()
index 4fd9a3e9b542eaf97be7c54c73934e5ae01a4e86..bc74fcafd336fa1f2bd72bce8311bd1a2bd8e2a5 100644 (file)
@@ -193,8 +193,9 @@ class PropertyLoader(StrategizedProperty):
         else:
             raise exceptions.ArgumentError("relation '%s' expects a class or a mapper argument (received: %s)" % (self.key, type(self.argument)))
             
-        self.mapper = self.mapper.get_select_mapper()._check_compile()
-            
+        # insure the "select_mapper", if different from the regular target mapper, is compiled.
+        self.mapper.get_select_mapper()._check_compile()
+           
         if self.association is not None:
             if isinstance(self.association, type):
                 self.association = mapper.class_mapper(self.association, compile=False)._check_compile()
@@ -220,6 +221,19 @@ class PropertyLoader(StrategizedProperty):
                     self.primaryjoin = sql.join(self.parent.unjoined_table, self.target).onclause
         except exceptions.ArgumentError, e:
             raise exceptions.ArgumentError("Error determining primary and/or secondary join for relationship '%s' between mappers '%s' and '%s'.  If the underlying error cannot be corrected, you should specify the 'primaryjoin' (and 'secondaryjoin', if there is an association table present) keyword arguments to the relation() function (or for backrefs, by specifying the backref using the backref() function with keyword arguments) to explicitly specify the join conditions.  Nested error is \"%s\"" % (self.key, self.parent, self.mapper, str(e)))
+
+        # if using polymorphic mapping, the join conditions must be agasint the base tables of the mappers,
+        # as the loader strategies expect to be working with those now (they will adapt the join conditions
+        # to the "polymorphic" selectable as needed).  since this is an API change, put an explicit check/
+        # error message in case its the "old" way.
+        if self.mapper.select_table is not self.mapper.mapped_table:
+            vis = sql_util.ColumnsInClause(self.mapper.select_table)
+            self.primaryjoin.accept_visitor(vis)
+            if self.secondaryjoin:
+                self.secondaryjoin.accept_visitor(vis)
+            if vis.result:
+                raise exceptions.ArgumentError("In relationship '%s' between mappers '%s' and '%s', primary and secondary join conditions must not include columns from the polymorphic 'select_table' argument as of SA release 0.3.4.  Construct join conditions using the base tables of the related mappers." % (self.key, self.parent, self.mapper))
+
         # if the foreign key wasnt specified and theres no assocaition table, try to figure
         # out who is dependent on who. we dont need all the foreign keys represented in the join,
         # just one of them.
@@ -280,10 +294,10 @@ class PropertyLoader(StrategizedProperty):
                 else:
                     return sync.MANYTOONE
         else:
-            onetomany = len([c for c in self.foreignkey if self.mapper.unjoined_table.corresponding_column(c, False, require_exact=True) is not None])
-            manytoone = len([c for c in self.foreignkey if self.parent.unjoined_table.corresponding_column(c, False, require_exact=True) is not None])
+            onetomany = len([c for c in self.foreignkey if self.mapper.unjoined_table.corresponding_column(c, False) is not None])
+            manytoone = len([c for c in self.foreignkey if self.parent.unjoined_table.corresponding_column(c, False) is not None])
             if not onetomany and not manytoone:
-                raise exceptions.ArgumentError("Cant determine relation direction for '%s' on mapper '%s' with primary join '%s' - foreign key columns are not present in neither the parent nor the child's mapped tables" %(self.key, str(self.parent), str(self.primaryjoin)))
+                raise exceptions.ArgumentError("Cant determine relation direction for '%s' on mapper '%s' with primary join '%s' - foreign key columns are not present in neither the parent nor the child's mapped tables" %(self.key, str(self.parent), str(self.primaryjoin)) +  str(self.foreignkey))
             elif onetomany and manytoone:
                 raise exceptions.ArgumentError("Cant determine relation direction for '%s' on mapper '%s' with primary join '%s' - foreign key columns are present in both the parent and the child's mapped tables.  Specify 'foreignkey' argument." %(self.key, str(self.parent), str(self.primaryjoin)))
             elif onetomany:
index 88a22bbcdd165e5c36647fc7ded7fbe91669a60d..2a33c4ff07d6972bf89f89beb7af1adfc9cbb325 100644 (file)
@@ -130,7 +130,10 @@ class AbstractRelationLoader(LoaderStrategy):
         self.secondary = self.parent_property.secondary
         self.foreignkey = self.parent_property.foreignkey
         self.mapper = self.parent_property.mapper
+        self.select_mapper = self.mapper.get_select_mapper()
         self.target = self.parent_property.target
+        self.select_table = self.parent_property.mapper.select_table
+        self.loads_polymorphic = self.target is not self.select_table
         self.uselist = self.parent_property.uselist
         self.cascade = self.parent_property.cascade
         self.attributeext = self.parent_property.attributeext
@@ -160,7 +163,7 @@ NoLoader.logger = logging.class_logger(NoLoader)
 class LazyLoader(AbstractRelationLoader):
     def init(self):
         super(LazyLoader, self).init()
-        (self.lazywhere, self.lazybinds, self.lazyreverse) = self._create_lazy_clause(self.parent.unjoined_table, self.primaryjoin, self.secondaryjoin, self.foreignkey, self.remote_side)
+        (self.lazywhere, self.lazybinds, self.lazyreverse) = self._create_lazy_clause(self.parent.unjoined_table, self.primaryjoin, self.secondaryjoin, self.foreignkey, self.remote_side, self.mapper.select_table)
         # determine if our "lazywhere" clause is the same as the mapper's
         # get() clause.  then we can just use mapper.get()
         self.use_get = not self.uselist and query.Query(self.mapper)._get_clause.compare(self.lazywhere)
@@ -244,11 +247,11 @@ class LazyLoader(AbstractRelationLoader):
                 # to load data into it.
                 sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
 
-    def _create_lazy_clause(self, table, primaryjoin, secondaryjoin, foreignkey, remote_side):
+    def _create_lazy_clause(self, table, primaryjoin, secondaryjoin, foreignkey, remote_side, select_table):
         binds = {}
         reverse = {}
         def column_in_table(table, column):
-            return table.corresponding_column(column, raiseerr=False, keys_ok=False, require_exact=True) is not None
+            return table.corresponding_column(column, raiseerr=False, keys_ok=False) is not None
 
         if remote_side is None or len(remote_side) == 0:
             remote_side = foreignkey
@@ -262,6 +265,13 @@ class LazyLoader(AbstractRelationLoader):
                     columns.append(c)
             expr.accept_visitor(FindColumnInColumnClause())
             return len(columns) and columns[0] or None
+        
+        def col_in_collection(column, collection):
+            for c in collection:
+                if column.shares_lineage(c):
+                    return True
+            else:
+                return False
                 
         def bind_label():
             return "lazy_" + hex(random.randint(0, 65535))[2:]
@@ -271,13 +281,13 @@ class LazyLoader(AbstractRelationLoader):
             if leftcol is None or rightcol is None:
                 return
             circular = leftcol.table is rightcol.table
-            if ((not circular and column_in_table(table, leftcol)) or (circular and rightcol in remote_side)):
+            if ((not circular and column_in_table(table, leftcol)) or (circular and col_in_collection(rightcol, remote_side))):
                 col = leftcol
                 binary.left = binds.setdefault(leftcol,
                         sql.bindparam(bind_label(), None, shortname=leftcol.name, type=binary.right.type))
                 reverse[rightcol] = binds[col]
 
-            if (leftcol is not rightcol) and ((not circular and column_in_table(table, rightcol)) or (circular and leftcol in remote_side)):
+            if (leftcol is not rightcol) and ((not circular and column_in_table(table, rightcol)) or (circular and col_in_collection(leftcol, remote_side))):
                 col = rightcol
                 binary.right = binds.setdefault(rightcol,
                         sql.bindparam(bind_label(), None, shortname=rightcol.name, type=binary.left.type))
@@ -286,9 +296,15 @@ class LazyLoader(AbstractRelationLoader):
         lazywhere = primaryjoin.copy_container()
         li = mapperutil.BinaryVisitor(visit_binary)
         lazywhere.accept_visitor(li)
+        
         if secondaryjoin is not None:
+            secondaryjoin = secondaryjoin.copy_container()
+            secondaryjoin.accept_visitor(sql_util.ClauseAdapter(select_table))
             lazywhere = sql.and_(lazywhere, secondaryjoin)
-        LazyLoader.logger.debug("create_lazy_clause " + str(lazywhere))
+        else:
+            lazywhere.accept_visitor(sql_util.ClauseAdapter(select_table))
+            
+        LazyLoader.logger.info("create_lazy_clause " + str(lazywhere))
         return (lazywhere, binds, reverse)
 
 LazyLoader.logger = logging.class_logger(LazyLoader)
@@ -299,7 +315,7 @@ class EagerLoader(AbstractRelationLoader):
     """loads related objects inline with a parent query."""
     def init(self):
         super(EagerLoader, self).init()
-        if self.parent.isa(self.mapper):
+        if self.parent.isa(self.select_mapper):
             raise exceptions.ArgumentError("Error creating eager relationship '%s' on parent class '%s' to child class '%s': Cant use eager loading on a self referential relationship." % (self.key, repr(self.parent.class_), repr(self.mapper.class_)))
         self.parent._eager_loaders.add(self.parent_property)
 
@@ -337,8 +353,9 @@ class EagerLoader(AbstractRelationLoader):
         """
         def __init__(self, eagerloader, parentclauses=None):
             self.parent = eagerloader
-            self.target = eagerloader.target
-            self.eagertarget = eagerloader.target.alias()
+            self.target = eagerloader.select_table
+            self.eagertarget = eagerloader.select_table.alias()
+            
             if eagerloader.secondary:
                 self.eagersecondary = eagerloader.secondary.alias()
                 self.aliasizer = sql_util.Aliasizer(eagerloader.target, eagerloader.secondary, aliases={
@@ -346,12 +363,16 @@ class EagerLoader(AbstractRelationLoader):
                         eagerloader.secondary:self.eagersecondary
                         })
                 self.eagersecondaryjoin = eagerloader.secondaryjoin.copy_container()
+                if eagerloader.loads_polymorphic:
+                    self.eagersecondaryjoin.accept_visitor(sql_util.ClauseAdapter(eagerloader.select_table))
                 self.eagersecondaryjoin.accept_visitor(self.aliasizer)
                 self.eagerprimary = eagerloader.primaryjoin.copy_container()
                 self.eagerprimary.accept_visitor(self.aliasizer)
             else:
-                self.aliasizer = sql_util.Aliasizer(eagerloader.target, aliases={eagerloader.target:self.eagertarget})
                 self.eagerprimary = eagerloader.primaryjoin.copy_container()
+                if eagerloader.loads_polymorphic:
+                    self.eagerprimary.accept_visitor(sql_util.ClauseAdapter(eagerloader.select_table))
+                self.aliasizer = sql_util.Aliasizer(self.target, aliases={self.target:self.eagertarget})
                 self.eagerprimary.accept_visitor(self.aliasizer)
 
             if parentclauses is not None:
@@ -460,8 +481,8 @@ class EagerLoader(AbstractRelationLoader):
             clauses._aliasize_orderby(statement.order_by_clause, False)
                 
         statement.append_from(statement._outerjoin)
-        for value in self.mapper.props.values():
-            value.setup(context, eagertable=clauses.eagertarget, parentclauses=clauses, parentmapper=self.mapper)
+        for value in self.select_mapper.props.values():
+            value.setup(context, eagertable=clauses.eagertarget, parentclauses=clauses, parentmapper=self.select_mapper)
 
     def _create_row_processor(self, selectcontext, row):
         """create a 'row processing' function that will apply eager aliasing to the row.
index bfbcff5541331b726c9f58ae50a91406301f7083..4c6cd4d073df5f679d8905296440b6a0733d86ff 100644 (file)
@@ -107,3 +107,51 @@ class Aliasizer(sql.ClauseVisitor):
             binary.left = self.get_alias(binary.left.table).corresponding_column(binary.left)
         if isinstance(binary.right, schema.Column) and self.tables.has_key(binary.right.table):
             binary.right = self.get_alias(binary.right.table).corresponding_column(binary.right)
+
+
+class ClauseAdapter(sql.ClauseVisitor):
+    """given a clause (like as in a WHERE criterion), locates columns which 'correspond' to a given selectable, 
+    and changes those columns to be that of the selectable.
+    
+        such as:
+        
+        table1 = Table('sometable', metadata, 
+            Column('col1', Integer),
+            Column('col2', Integer)
+            )
+        table2 = Table('someothertable', metadata, 
+            Column('col1', Integer),
+            Column('col2', Integer)
+            )
+    
+        condition = table1.c.col1 == table2.c.col1
+        
+        and make an alias of table1:
+        
+        s = table1.alias('foo')
+        
+        calling condition.accept_visitor(ClauseAdapter(s)) converts condition to read:
+        
+        s.c.col1 == table2.c.col1
+    
+    """
+    def __init__(self, selectable):
+        self.selectable = selectable
+    def visit_binary(self, binary):
+        if isinstance(binary.left, sql.ColumnElement):
+            col = self.selectable.corresponding_column(binary.left, raiseerr=False, keys_ok=False)
+            if col is not None:
+                binary.left = col
+        if isinstance(binary.right, sql.ColumnElement):
+            col = self.selectable.corresponding_column(binary.right, raiseerr=False, keys_ok=False)
+            if col is not None:
+                binary.right = col
+
+class ColumnsInClause(sql.ClauseVisitor):
+    """given a selectable, visits clauses and determines if any columns from the clause are in the selectable"""
+    def __init__(self, selectable):
+        self.selectable = selectable
+        self.result = False
+    def visit_column(self, column):
+        if self.selectable.c.get(column.key) is column:
+            self.result = True
index e54bcb8c83d0d000542df91892914405ae9841a7..7d59c5d547709a561cbfca98ffb4b7cc7fb465c6 100644 (file)
@@ -99,7 +99,7 @@ class RelationTest2(testbase.AssertMixin):
         for t in metadata.table_iterator(reverse=True):
             t.delete().execute()
 
-    def testbasic(self):
+    def testrelationonsubclass(self):
         class Person(AttrSettable):
             pass
         class Manager(Person):
@@ -115,7 +115,7 @@ class RelationTest2(testbase.AssertMixin):
               properties={
                 'colleague':relation(Person, primaryjoin=managers.c.manager_id==people.c.person_id, uselist=False)
         })
-        
+        class_mapper(Person).compile()
         sess = create_session()
         p = Person(name='person1')
         m = Manager(name='manager1')
@@ -130,6 +130,69 @@ class RelationTest2(testbase.AssertMixin):
         print m
         assert m.colleague is p
 
+class RelationTest3(testbase.AssertMixin):
+    """test self-referential relationships on polymorphic mappers"""
+    def setUpAll(self):
+        global people, managers, metadata
+        metadata = BoundMetaData(testbase.db)
+
+        people = Table('people', metadata, 
+           Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
+           Column('colleague_id', Integer, ForeignKey('people.person_id')),
+           Column('name', String(50)),
+           Column('type', String(30)))
+
+        managers = Table('managers', metadata, 
+           Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
+           Column('status', String(30)),
+           )
+
+        metadata.create_all()
+
+    def tearDownAll(self):
+        metadata.drop_all()
+
+    def tearDown(self):
+        clear_mappers()
+        for t in metadata.table_iterator(reverse=True):
+            t.delete().execute()
+
+    def testrelationonbaseclass(self):
+        class Person(AttrSettable):
+            pass
+        class Manager(Person):
+            pass
+
+        poly_union = polymorphic_union({
+            'manager':managers.join(people, people.c.person_id==managers.c.person_id),
+            'person':people.select(people.c.type=='person')
+        }, None)
+
+        mapper(Person, people, select_table=poly_union, polymorphic_identity='person', polymorphic_on=people.c.type,
+              properties={
+                'colleagues':relation(Person, primaryjoin=people.c.colleague_id==people.c.person_id, 
+                    remote_side=people.c.person_id, uselist=True)
+                }        
+        )
+        mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, polymorphic_identity='manager')
+
+        sess = create_session()
+        p = Person(name='person1')
+        p2 = Person(name='person2')
+        m = Manager(name='manager1')
+        p.colleagues.append(p2)
+        m.colleagues.append(p2)
+        sess.save(m)
+        sess.save(p)
+        sess.flush()
+        
+        sess.clear()
+        p = sess.query(Person).get(p.person_id)
+        p2 = sess.query(Person).get(p2.person_id)
+        print p, p2, p.colleagues
+        assert len(p.colleagues) == 1
+        assert p.colleagues == [p2]
+
 if __name__ == "__main__":    
     testbase.main()
         
\ No newline at end of file
index e3e10156ee754ecf4166c770344eef10d8d6703f..adf844c5c6ab4e6ec40ac900ac4891f9cff72ac6 100644 (file)
@@ -62,6 +62,25 @@ class PolymorphicCircularTest(testbase.PersistTest):
             def __repr__(self):
                 return "%s(%d, %s)" % (self.__class__.__name__, self.id, repr(str(self.data)))
             
+        try:
+            # this is how the mapping used to work.  insure that this raises an error now
+            table1_mapper = mapper(Table1, table1,
+                                   select_table=join,
+                                   polymorphic_on=join.c.type,
+                                   polymorphic_identity='table1',
+                                   properties={
+                                    'next': relation(Table1, 
+                                        backref=backref('prev', primaryjoin=join.c.id==join.c.related_id, foreignkey=join.c.id, uselist=False), 
+                                        uselist=False, lazy=False, primaryjoin=join.c.id==join.c.related_id),
+                                    'data':relation(mapper(Data, data), lazy=False)
+                                    }
+                            )
+            table1_mapper.compile()
+            assert False
+        except:
+            assert True
+            clear_mappers()
+            
         # currently, all of these "eager" relationships degrade to lazy relationships
         # due to the polymorphic load.
         table1_mapper = mapper(Table1, table1,
@@ -69,12 +88,14 @@ class PolymorphicCircularTest(testbase.PersistTest):
                                polymorphic_on=join.c.type,
                                polymorphic_identity='table1',
                                properties={
-                                'next': relation(Table1, 
-                                    backref=backref('prev', primaryjoin=join.c.id==join.c.related_id, foreignkey=join.c.id, uselist=False), 
-                                    uselist=False, lazy=False, primaryjoin=join.c.id==join.c.related_id),
-                                'data':relation(mapper(Data, data), lazy=False)
+                               'next': relation(Table1, 
+                                   backref=backref('prev', primaryjoin=table1.c.id==table1.c.related_id, remote_side=table1.c.id, uselist=False), 
+                                   uselist=False, lazy=False, primaryjoin=table1.c.id==table1.c.related_id),
+                               'data':relation(mapper(Data, data), lazy=False)
                                 }
                         )
+        
+
 
         table1b_mapper = mapper(Table1B, inherits=table1_mapper, polymorphic_identity='table1b')
 
index 5a3c1292c6cabbd8b8c4160f1ded4976f1065ae1..5cac1d0ab160efffe53d66d72a7b8fe22c3d1430 100644 (file)
@@ -84,6 +84,22 @@ class MultipleTableTest(testbase.PersistTest):
         self.do_test(True, True, False)
     def test_t_t_t(self):
         self.do_test(True, True, True)
+    def test_f_f_f_t(self):
+        self.do_test(False, False, False, True)
+    def test_f_f_t_t(self):
+        self.do_test(False, False, True, True)
+    def test_f_t_f_t(self):
+        self.do_test(False, True, False, True)
+    def test_f_t_t_t(self):
+        self.do_test(False, True, True, True)
+    def test_t_f_f_t(self):
+        self.do_test(True, False, False, True)
+    def test_t_f_t_t(self):
+        self.do_test(True, False, True, True)
+    def test_t_t_f_t(self):
+        self.do_test(True, True, False, True)
+    def test_t_t_t_t(self):
+        self.do_test(True, True, True, True)
         
     def testcompile(self):
         person_join = polymorphic_union( {
@@ -147,7 +163,7 @@ class MultipleTableTest(testbase.PersistTest):
         except exceptions.ArgumentError:
             assert True
         
-    def do_test(self, include_base=False, lazy_relation=True, redefine_colprop=False):
+    def do_test(self, include_base=False, lazy_relation=True, redefine_colprop=False, use_literal_join=False):
         """tests the polymorph.py example, with several options:
         
         include_base - whether or not to include the base 'person' type in the union.
@@ -177,10 +193,19 @@ class MultipleTableTest(testbase.PersistTest):
         mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer')
         mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager')
 
-        mapper(Company, companies, properties={
-            'employees': relation(Person, lazy=lazy_relation, private=True, backref='company')
-        })
-
+        if use_literal_join:
+            mapper(Company, companies, properties={
+                'employees': relation(Person, lazy=lazy_relation, primaryjoin=people.c.company_id==companies.c.company_id, private=True, 
+                backref="company"
+                )
+            })
+        else:
+            mapper(Company, companies, properties={
+                'employees': relation(Person, lazy=lazy_relation, private=True, 
+                backref="company"
+                )
+            })
+            
         if redefine_colprop:
             person_attribute_name = 'person_name'
         else: