]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Two fixes to help prevent out-of-band columns from
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 3 Dec 2008 17:28:36 +0000 (17:28 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 3 Dec 2008 17:28:36 +0000 (17:28 +0000)
being rendered in polymorphic_union inheritance
scenarios (which then causes extra tables to be
rendered in the FROM clause causing cartesian
products):
- improvements to "column adaption" for
  a->b->c inheritance situations to better
  locate columns that are related to one
  another via multiple levels of indirection,
  rather than rendering the non-adapted
  column.
- the "polymorphic discriminator" column is
  only rendered for the actual mapper being
  queried against. The column won't be
  "pulled in" from a subclass or superclass
  mapper since it's not needed.

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/sql/util.py
test/orm/inheritance/concrete.py
test/sql/generative.py

diff --git a/CHANGES b/CHANGES
index 562a5baa8ff992a3f9c8282535ab5091a60020d6..f1245c7bfafc64293a2d2d7e05f041bf1d50cace 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -61,6 +61,23 @@ CHANGES
     - made Session.merge cascades not trigger autoflush.
       Fixes merged instances getting prematurely inserted
       with missing values.
+
+    - Two fixes to help prevent out-of-band columns from
+      being rendered in polymorphic_union inheritance
+      scenarios (which then causes extra tables to be
+      rendered in the FROM clause causing cartesian 
+      products): 
+        - improvements to "column adaption" for
+          a->b->c inheritance situations to better
+          locate columns that are related to one
+          another via multiple levels of indirection,
+          rather than rendering the non-adapted
+          column.
+        - the "polymorphic discriminator" column is
+          only rendered for the actual mapper being
+          queried against. The column won't be
+          "pulled in" from a subclass or superclass
+          mapper since it's not needed.
       
 - sql
     - Fixed the import weirdness in sqlalchemy.sql
index b48297a6451516c9ecd976e91bda12041d744130..48c2f9e27f590b8e0a198689d0feac7419972cf8 100644 (file)
@@ -612,7 +612,13 @@ class Mapper(object):
                 # right set
                 if hasattr(self, '_cols_by_table') and col.table in self._cols_by_table and col not in self._cols_by_table[col.table]:
                     self._cols_by_table[col.table].add(col)
-
+            
+            # if this ColumnProperty represents the "polymorphic discriminator"
+            # column, mark it.  We'll need this when rendering columns
+            # in SELECT statements.
+            if not hasattr(prop, '_is_polymorphic_discriminator'):
+                prop._is_polymorphic_discriminator = (col is self.polymorphic_on or prop.columns[0] is self.polymorphic_on)
+                
             self.columns[key] = col
             for col in prop.columns:
                 for col in col.proxy_set:
@@ -860,20 +866,27 @@ class Mapper(object):
         else:
             return mappers, self._selectable_from_mappers(mappers)
 
-    @property
-    def _default_polymorphic_properties(self):
-        return util.unique_list(
-            chain(*[list(mapper.iterate_properties) for mapper in [self] + self._with_polymorphic_mappers])
-        )
-        
     def _iterate_polymorphic_properties(self, mappers=None):
+        """Return an iterator of MapperProperty objects which will render into a SELECT."""
+        
         if mappers is None:
-            return iter(self._default_polymorphic_properties)
+            mappers = self._with_polymorphic_mappers
+
+        if not mappers:
+            for c in self.iterate_properties:
+                yield c
         else:
-            return iter(util.unique_list(
+            # in the polymorphic case, filter out discriminator columns
+            # from other mappers, as these are sometimes dependent on that
+            # mapper's polymorphic selectable (which we don't want rendered)
+            for c in util.unique_list(
                 chain(*[list(mapper.iterate_properties) for mapper in [self] + mappers])
-            ))
-
+            ):
+                if getattr(c, '_is_polymorphic_discriminator', False) and \
+                    (not self.polymorphic_on or c.columns[0] is not self.polymorphic_on):
+                        continue
+                yield c
+    
     @property
     def properties(self):
         raise NotImplementedError("Public collection of MapperProperty objects is "
index d5f2417c27bce1b4756b94766e14db0b3dbf59cf..1bd4dd857d77d6a443a190edb552f77d80c65edf 100644 (file)
@@ -439,12 +439,12 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
         self.exclude = exclude
         self.equivalents = equivalents or {}
         
-    def _corresponding_column(self, col, require_embedded):
+    def _corresponding_column(self, col, require_embedded, _seen=util.EMPTY_SET):
         newcol = self.selectable.corresponding_column(col, require_embedded=require_embedded)
 
-        if not newcol and col in self.equivalents:
+        if not newcol and col in self.equivalents and col not in _seen:
             for equiv in self.equivalents[col]:
-                newcol = self.selectable.corresponding_column(equiv, require_embedded=require_embedded)
+                newcol = self._corresponding_column(equiv, require_embedded=require_embedded, _seen=_seen.union([col]))
                 if newcol:
                     return newcol
         return newcol
index b8d799b60482d6c909e9a10964f31bae3d41409f..e6277f3e919978138a5fd2d86863673d42acabc3 100644 (file)
@@ -7,12 +7,18 @@ from sqlalchemy.orm import attributes
 
 class ConcreteTest(ORMTest):
     def define_tables(self, metadata):
-        global managers_table, engineers_table, hackers_table, companies
+        global managers_table, engineers_table, hackers_table, companies, employees_table
 
         companies = Table('companies', metadata,
            Column('id', Integer, primary_key=True),
            Column('name', String(50)))
 
+        employees_table = Table('employees', metadata,
+            Column('employee_id', Integer, primary_key=True),
+            Column('name', String(50)),
+            Column('company_id', Integer, ForeignKey('companies.id'))
+        )
+        
         managers_table = Table('managers', metadata,
             Column('employee_id', Integer, primary_key=True),
             Column('name', String(50)),
@@ -80,7 +86,7 @@ class ConcreteTest(ORMTest):
         session.expire(manager, ['manager_data'])
         self.assertEquals(manager.manager_data, "knows how to manage things")
 
-    def test_multi_level(self):
+    def test_multi_level_no_base(self):
         class Employee(object):
             def __init__(self, name):
                 self.name = name
@@ -157,6 +163,92 @@ class ConcreteTest(ORMTest):
         assert set([repr(x) for x in session.query(Employee).all()]) == set(["Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"])
         assert set([repr(x) for x in session.query(Manager).all()]) == set(["Manager Tom knows how to manage things"])
         assert set([repr(x) for x in session.query(Engineer).all()]) == set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"])
+        assert set([repr(x) for x in session.query(Hacker).all()]) == set(["Hacker Kurt 'Badass' knows how to hack"])
+
+    def test_multi_level_with_base(self):
+        class Employee(object):
+            def __init__(self, name):
+                self.name = name
+            def __repr__(self):
+                return self.__class__.__name__ + " " + self.name
+
+        class Manager(Employee):
+            def __init__(self, name, manager_data):
+                self.name = name
+                self.manager_data = manager_data
+            def __repr__(self):
+                return self.__class__.__name__ + " " + self.name + " " +  self.manager_data
+
+        class Engineer(Employee):
+            def __init__(self, name, engineer_info):
+                self.name = name
+                self.engineer_info = engineer_info
+            def __repr__(self):
+                return self.__class__.__name__ + " " + self.name + " " +  self.engineer_info
+
+        class Hacker(Engineer):
+            def __init__(self, name, nickname, engineer_info):
+                self.name = name
+                self.nickname = nickname
+                self.engineer_info = engineer_info
+            def __repr__(self):
+                return self.__class__.__name__ + " " + self.name + " '" + \
+                       self.nickname + "' " +  self.engineer_info
+
+        pjoin = polymorphic_union({
+            'employee':employees_table,
+            'manager': managers_table,
+            'engineer': engineers_table,
+            'hacker': hackers_table
+        }, 'type', 'pjoin')
+
+        pjoin2 = polymorphic_union({
+            'engineer': engineers_table,
+            'hacker': hackers_table
+        }, 'type', 'pjoin2')
+
+        employee_mapper = mapper(Employee, employees_table, with_polymorphic=('*', pjoin), polymorphic_on=pjoin.c.type)
+        manager_mapper = mapper(Manager, managers_table, 
+                                inherits=employee_mapper, concrete=True, 
+                                polymorphic_identity='manager')
+        engineer_mapper = mapper(Engineer, engineers_table, 
+                                 with_polymorphic=('*', pjoin2), 
+                                 polymorphic_on=pjoin2.c.type,
+                                 inherits=employee_mapper, concrete=True,
+                                 polymorphic_identity='engineer')
+        hacker_mapper = mapper(Hacker, hackers_table, 
+                               inherits=engineer_mapper,
+                               concrete=True, polymorphic_identity='hacker')
+
+        session = create_session()
+        tom = Manager('Tom', 'knows how to manage things')
+        jerry = Engineer('Jerry', 'knows how to program')
+        hacker = Hacker('Kurt', 'Badass', 'knows how to hack')
+        session.add_all((tom, jerry, hacker))
+        session.flush()
+
+        # ensure "readonly" on save logic didn't pollute the expired_attributes
+        # collection
+        assert 'nickname' not in attributes.instance_state(jerry).expired_attributes
+        assert 'name' not in attributes.instance_state(jerry).expired_attributes
+        assert 'name' not in attributes.instance_state(hacker).expired_attributes
+        assert 'nickname' not in attributes.instance_state(hacker).expired_attributes
+        def go():
+            self.assertEquals(jerry.name, "Jerry")
+            self.assertEquals(hacker.nickname, "Badass")
+        self.assert_sql_count(testing.db, go, 0)
+
+        session.clear()
+
+        # check that we aren't getting a cartesian product in the raw SQL.
+        # this requires that Engineer's polymorphic discriminator is not rendered
+        # in the statement which is only against Employee's "pjoin"
+        assert len(testing.db.execute(session.query(Employee).with_labels().statement).fetchall()) == 3
+        
+        assert set([repr(x) for x in session.query(Employee).all()]) == set(["Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"])
+        assert set([repr(x) for x in session.query(Manager).all()]) == set(["Manager Tom knows how to manage things"])
+        assert set([repr(x) for x in session.query(Engineer).all()]) == set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"])
+        assert set([repr(x) for x in session.query(Hacker).all()]) == set(["Hacker Kurt 'Badass' knows how to hack"])
 
     def test_relation(self):
         class Employee(object):
index f6b849e8a35b9384dd5af28b0836f704eb240a8f..4edf334f667db5220baf54047e45712fc60db91f 100644 (file)
@@ -458,6 +458,32 @@ class ClauseAdapterTest(TestBase, AssertsCompiledSQL):
 
         assert str(e) == "a_1.id = a.xxx_id"
 
+    def test_recursive_equivalents(self):
+        m = MetaData()
+        a = Table('a', m, Column('x', Integer), Column('y', Integer))
+        b = Table('b', m, Column('x', Integer), Column('y', Integer))
+        c = Table('c', m, Column('x', Integer), Column('y', Integer))
+        
+        # force a recursion overflow, by linking a.c.x<->c.c.x, and
+        # asking for a nonexistent col.  corresponding_column should prevent
+        # endless depth.
+        adapt = sql_util.ClauseAdapter( b, equivalents= {a.c.x: set([ c.c.x]), c.c.x:set([a.c.x])})
+        assert adapt._corresponding_column(a.c.x, False) is None
+
+    def test_multilevel_equivalents(self):
+        m = MetaData()
+        a = Table('a', m, Column('x', Integer), Column('y', Integer))
+        b = Table('b', m, Column('x', Integer), Column('y', Integer))
+        c = Table('c', m, Column('x', Integer), Column('y', Integer))
+
+        alias = select([a]).select_from(a.join(b, a.c.x==b.c.x)).alias()
+        
+        # two levels of indirection from c.x->b.x->a.x, requires recursive 
+        # corresponding_column call
+        adapt = sql_util.ClauseAdapter(alias, equivalents= {b.c.x: set([ a.c.x]), c.c.x:set([b.c.x])})
+        assert adapt._corresponding_column(a.c.x, False) is alias.c.x
+        assert adapt._corresponding_column(c.c.x, False) is alias.c.x
+        
     def test_join_to_alias(self):
         metadata = MetaData()
         a = Table('a', metadata,