]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed bug whereby a load/refresh of joined table
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Jul 2009 21:26:28 +0000 (21:26 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Jul 2009 21:26:28 +0000 (21:26 +0000)
inheritance attributes which were based on
column_property() or similar would fail to evaluate.
[ticket:1480]

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/sql/util.py
test/orm/inheritance/test_basic.py

diff --git a/CHANGES b/CHANGES
index b13600cad833e2e443fd626f09997d3e82f72b7e..38bdb15b8a9b76d21989b63c61b6eb98a8cb49c2 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -16,7 +16,12 @@ CHANGES
       during a flush.  This is currently to support 
       many-to-many relations from concrete inheritance setups.
       Outside of that use case, YMMV.  [ticket:1477]
-
+    
+    - Fixed bug whereby a load/refresh of joined table
+      inheritance attributes which were based on 
+      column_property() or similar would fail to evaluate.
+      [ticket:1480]
+      
     - Improved error message when query() is called with
       a non-SQL /entity expression. [ticket:1476]
     
index 9e939c918ad0ae78051f3be85ab3a1f5d19a631f..aac271efecfbfc4c19fb25967aa0d5073e42068b 100644 (file)
@@ -1101,7 +1101,12 @@ class Mapper(object):
         
         """
         props = self._props
-        tables = set(props[key].parent.local_table for key in attribute_names)
+        
+        tables = set(chain(*
+                        (sqlutil.find_tables(props[key].columns[0], check_columns=True) 
+                        for key in attribute_names)
+                    ))
+        
         if self.base_mapper.local_table in tables:
             return None
 
@@ -1138,7 +1143,8 @@ class Mapper(object):
             return None
 
         cond = sql.and_(*allconds)
-        return sql.select(tables, cond, use_labels=True)
+
+        return sql.select([props[key].columns[0] for key in attribute_names], cond, use_labels=True)
 
     def cascade_iterator(self, type_, state, halt_on=None):
         """Iterate each element and its mapper in an object graph,
index f1f329b5e27a31b60b52bd022b2ab69bedb903e2..ac95c3a20950a8d45eb66dce9a6322d8cde0171a 100644 (file)
@@ -53,24 +53,21 @@ def find_tables(clause, check_columns=False, include_aliases=False, include_join
     tables = []
     _visitors = {}
     
-    def visit_something(elem):
-        tables.append(elem)
-        
     if include_selects:
-        _visitors['select'] = _visitors['compound_select'] = visit_something
+        _visitors['select'] = _visitors['compound_select'] = tables.append
     
     if include_joins:
-        _visitors['join'] = visit_something
+        _visitors['join'] = tables.append
         
     if include_aliases:
-        _visitors['alias']  = visit_something
+        _visitors['alias']  = tables.append
 
     if check_columns:
         def visit_column(column):
             tables.append(column.table)
         _visitors['column'] = visit_column
 
-    _visitors['table'] = visit_something
+    _visitors['table'] = tables.append
 
     visitors.traverse(clause, {'column_collections':False}, _visitors)
     return tables
index 6aa77868ea2a8e2019455dde9cef253d21481006..bad6920de7ad2a16716097fe7fb92711144309cd 100644 (file)
@@ -915,10 +915,8 @@ class OverrideColKeyTest(_base.MappedTest):
         assert sess.query(Sub).get(s1.base_id).data == "this is base"
 
 class OptimizedLoadTest(_base.MappedTest):
-    """test that the 'optimized load' routine doesn't crash when 
-    a column in the join condition is not available.
+    """tests for the "optimized load" routine."""
     
-    """
     @classmethod
     def define_tables(cls, metadata):
         global base, sub
@@ -933,7 +931,10 @@ class OptimizedLoadTest(_base.MappedTest):
         )
     
     def test_optimized_passes(self):
-        class Base(object):
+        """"test that the 'optimized load' routine doesn't crash when 
+        a column in the join condition is not available."""
+        
+        class Base(_base.BasicEntity):
             pass
         class Sub(Base):
             pass
@@ -943,21 +944,66 @@ class OptimizedLoadTest(_base.MappedTest):
         # redefine Sub's "id" to favor the "id" col in the subtable.
         # "id" is also part of the primary join condition
         mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={'id':sub.c.id})
-        sess = create_session()
-        s1 = Sub()
-        s1.data = 's1data'
-        s1.sub = 's1sub'
+        sess = sessionmaker()()
+        s1 = Sub(data='s1data', sub='s1sub')
         sess.add(s1)
-        sess.flush()
+        sess.commit()
         sess.expunge_all()
         
         # load s1 via Base.  s1.id won't populate since it's relative to 
         # the "sub" table.  The optimized load kicks in and tries to 
         # generate on the primary join, but cannot since "id" is itself unloaded.
         # the optimized load needs to return "None" so regular full-row loading proceeds
-        s1 = sess.query(Base).get(s1.id)
+        s1 = sess.query(Base).first()
         assert s1.sub == 's1sub'
 
+    def test_column_expression(self):
+        class Base(_base.BasicEntity):
+            pass
+        class Sub(Base):
+            pass
+        mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base')
+        mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={
+            'concat':column_property(sub.c.sub + "|" + sub.c.sub)
+        })
+        sess = sessionmaker()()
+        s1 = Sub(data='s1data', sub='s1sub')
+        sess.add(s1)
+        sess.commit()
+        sess.expunge_all()
+        s1 = sess.query(Base).first()
+        assert s1.concat == 's1sub|s1sub'
+
+    def test_column_expression_joined(self):
+        class Base(_base.ComparableEntity):
+            pass
+        class Sub(Base):
+            pass
+        mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base')
+        mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={
+            'concat':column_property(base.c.data + "|" + sub.c.sub)
+        })
+        sess = sessionmaker()()
+        s1 = Sub(data='s1data', sub='s1sub')
+        s2 = Sub(data='s2data', sub='s2sub')
+        s3 = Sub(data='s3data', sub='s3sub')
+        sess.add_all([s1, s2, s3])
+        sess.commit()
+        sess.expunge_all()
+        # query a bunch of rows to ensure there's no cartesian
+        # product against "base" occurring, it is in fact
+        # detecting that "base" needs to be in the join 
+        # criterion
+        eq_(
+            sess.query(Base).order_by(Base.id).all(),
+            [
+                Sub(data='s1data', sub='s1sub', concat='s1data|s1sub'),
+                Sub(data='s2data', sub='s2sub', concat='s2data|s2sub'),
+                Sub(data='s3data', sub='s3sub', concat='s3data|s3sub')
+            ]
+        )
+        
+        
 class PKDiscriminatorTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):