]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
scaled back the equivalents determined in _equivalent_columns to just current polymor...
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 12 May 2008 22:44:04 +0000 (22:44 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 12 May 2008 22:44:04 +0000 (22:44 +0000)
 behavior, fixes [ticket:1041]

lib/sqlalchemy/orm/mapper.py
test/orm/query.py

index 5569f9216b63ac04488baa81d2f7d96a4ba6f47d..0ff1a05a0f0d8f462978b23a27e5158db1df5d1b 100644 (file)
@@ -560,8 +560,11 @@ class Mapper(object):
     def _equivalent_columns(self):
         """Create a map of all *equivalent* columns, based on
         the determination of column pairs that are equated to
-        one another either by an established foreign key relationship
-        or by a joined-table inheritance join.
+        one another based on inherit condition.  This is designed
+        to work with the queries that util.polymorphic_union 
+        comes up with, which often don't include the columns from
+        the base table directly (including the subclass table columns 
+        only).
 
         The resulting structure is a dictionary of columns mapped
         to lists of equivalent columns, i.e.
@@ -590,30 +593,6 @@ class Mapper(object):
             if mapper.inherit_condition:
                 visitors.traverse(mapper.inherit_condition, {}, {'binary':visit_binary})
 
-        # TODO: matching of cols to foreign keys might better be generalized
-        # into general column translation (i.e. corresponding_column)
-
-        # recursively descend into the foreign key collection of the given column
-        # and assemble each FK-related col as an "equivalent" for the given column
-        def equivs(col, recursive, equiv):
-            if col in recursive:
-                return
-            recursive.add(col)
-            for fk in col.foreign_keys:
-                if fk.column not in result:
-                    result[fk.column] = util.Set()
-                result[fk.column].add(equiv)
-                equivs(fk.column, recursive, col)
-
-        for column in (self.primary_key_argument or self._pks_by_table[self.mapped_table]):
-            for col in column.proxy_set:
-                if not col.foreign_keys:
-                    if col not in result:
-                        result[col] = util.Set()
-                    result[col].add(col)
-                else:
-                    equivs(col, util.Set(), col)
-
         return result
     _equivalent_columns = property(util.cache_decorator(_equivalent_columns))
 
index 46a3cbacd250233134ce57dc91b51b8ee71224d9..99d89d6696a01a1107974e4859812fbfa2ae253a 100644 (file)
@@ -7,6 +7,7 @@ from sqlalchemy.engine import default
 from sqlalchemy.orm import *
 
 from testlib import *
+from orm import _base
 from testlib import engines
 from testlib.fixtures import *
 
@@ -1985,6 +1986,55 @@ class ExternalColumnsTest(QueryTest):
             self.assertEquals(o1.address.user.count, 1)
         self.assert_sql_count(testing.db, go, 1)
 
+class TestOverlyEagerEquivalentCols(_base.MappedTest):
+    def define_tables(self, metadata):
+        global base, sub1, sub2
+        base = Table('base', metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('data', String(50))
+        )
+
+        sub1 = Table('sub1', metadata, 
+            Column('id', Integer, ForeignKey('base.id'), primary_key=True),
+            Column('data', String(50))
+        )
+
+        sub2 = Table('sub2', metadata, 
+            Column('id', Integer, ForeignKey('base.id'), ForeignKey('sub1.id'), primary_key=True),
+            Column('data', String(50))
+        )
+    
+    def test_equivs(self):
+        class Base(_base.ComparableEntity):
+            pass
+        class Sub1(_base.ComparableEntity):
+            pass
+        class Sub2(_base.ComparableEntity):
+            pass
+        
+        mapper(Base, base, properties={
+            'sub1':relation(Sub1),
+            'sub2':relation(Sub2)
+        })
+        
+        mapper(Sub1, sub1)
+        mapper(Sub2, sub2)
+        sess = create_session()
+        b1 = Base(data='b1', sub1=[Sub1(data='s11')], sub2=[])
+        b2 = Base(data='b1', sub1=[Sub1(data='s12')], sub2=[Sub2(data='s2')])
+        sess.add(b1)
+        sess.add(b2)
+        sess.flush()
+        
+        q = sess.query(Base).outerjoin('sub2', aliased=True)
+        assert sub1.c.id not in q._filter_aliases.equivalents
+
+        self.assertEquals(
+            sess.query(Base).join('sub1').outerjoin('sub2', aliased=True).\
+                filter(Sub1.id==1).one(),
+                b1
+        )
+        
 
 if __name__ == '__main__':
     testenv.main()