From: Mike Bayer Date: Sat, 25 Jul 2009 21:26:28 +0000 (+0000) Subject: - Fixed bug whereby a load/refresh of joined table X-Git-Tag: rel_0_5_6~28 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=c1a36dfe4142cf630d0d3f4056fae43902cbcf6b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Fixed bug whereby a load/refresh of joined table inheritance attributes which were based on column_property() or similar would fail to evaluate. [ticket:1480] --- diff --git a/CHANGES b/CHANGES index b13600cad8..38bdb15b8a 100644 --- a/CHANGES +++ b/CHANGES @@ -16,7 +16,12 @@ CHANGES during a flush. This is currently to support many-to-many relations from concrete inheritance setups. Outside of that use case, YMMV. [ticket:1477] - + + - Fixed bug whereby a load/refresh of joined table + inheritance attributes which were based on + column_property() or similar would fail to evaluate. + [ticket:1480] + - Improved error message when query() is called with a non-SQL /entity expression. [ticket:1476] diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 9e939c918a..aac271efec 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1101,7 +1101,12 @@ class Mapper(object): """ props = self._props - tables = set(props[key].parent.local_table for key in attribute_names) + + tables = set(chain(* + (sqlutil.find_tables(props[key].columns[0], check_columns=True) + for key in attribute_names) + )) + if self.base_mapper.local_table in tables: return None @@ -1138,7 +1143,8 @@ class Mapper(object): return None cond = sql.and_(*allconds) - return sql.select(tables, cond, use_labels=True) + + return sql.select([props[key].columns[0] for key in attribute_names], cond, use_labels=True) def cascade_iterator(self, type_, state, halt_on=None): """Iterate each element and its mapper in an object graph, diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index f1f329b5e2..ac95c3a209 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -53,24 +53,21 @@ def find_tables(clause, check_columns=False, include_aliases=False, include_join tables = [] _visitors = {} - def visit_something(elem): - tables.append(elem) - if include_selects: - _visitors['select'] = _visitors['compound_select'] = visit_something + _visitors['select'] = _visitors['compound_select'] = tables.append if include_joins: - _visitors['join'] = visit_something + _visitors['join'] = tables.append if include_aliases: - _visitors['alias'] = visit_something + _visitors['alias'] = tables.append if check_columns: def visit_column(column): tables.append(column.table) _visitors['column'] = visit_column - _visitors['table'] = visit_something + _visitors['table'] = tables.append visitors.traverse(clause, {'column_collections':False}, _visitors) return tables diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index 6aa77868ea..bad6920de7 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -915,10 +915,8 @@ class OverrideColKeyTest(_base.MappedTest): assert sess.query(Sub).get(s1.base_id).data == "this is base" class OptimizedLoadTest(_base.MappedTest): - """test that the 'optimized load' routine doesn't crash when - a column in the join condition is not available. + """tests for the "optimized load" routine.""" - """ @classmethod def define_tables(cls, metadata): global base, sub @@ -933,7 +931,10 @@ class OptimizedLoadTest(_base.MappedTest): ) def test_optimized_passes(self): - class Base(object): + """"test that the 'optimized load' routine doesn't crash when + a column in the join condition is not available.""" + + class Base(_base.BasicEntity): pass class Sub(Base): pass @@ -943,21 +944,66 @@ class OptimizedLoadTest(_base.MappedTest): # redefine Sub's "id" to favor the "id" col in the subtable. # "id" is also part of the primary join condition mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={'id':sub.c.id}) - sess = create_session() - s1 = Sub() - s1.data = 's1data' - s1.sub = 's1sub' + sess = sessionmaker()() + s1 = Sub(data='s1data', sub='s1sub') sess.add(s1) - sess.flush() + sess.commit() sess.expunge_all() # load s1 via Base. s1.id won't populate since it's relative to # the "sub" table. The optimized load kicks in and tries to # generate on the primary join, but cannot since "id" is itself unloaded. # the optimized load needs to return "None" so regular full-row loading proceeds - s1 = sess.query(Base).get(s1.id) + s1 = sess.query(Base).first() assert s1.sub == 's1sub' + def test_column_expression(self): + class Base(_base.BasicEntity): + pass + class Sub(Base): + pass + mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base') + mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={ + 'concat':column_property(sub.c.sub + "|" + sub.c.sub) + }) + sess = sessionmaker()() + s1 = Sub(data='s1data', sub='s1sub') + sess.add(s1) + sess.commit() + sess.expunge_all() + s1 = sess.query(Base).first() + assert s1.concat == 's1sub|s1sub' + + def test_column_expression_joined(self): + class Base(_base.ComparableEntity): + pass + class Sub(Base): + pass + mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base') + mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={ + 'concat':column_property(base.c.data + "|" + sub.c.sub) + }) + sess = sessionmaker()() + s1 = Sub(data='s1data', sub='s1sub') + s2 = Sub(data='s2data', sub='s2sub') + s3 = Sub(data='s3data', sub='s3sub') + sess.add_all([s1, s2, s3]) + sess.commit() + sess.expunge_all() + # query a bunch of rows to ensure there's no cartesian + # product against "base" occurring, it is in fact + # detecting that "base" needs to be in the join + # criterion + eq_( + sess.query(Base).order_by(Base.id).all(), + [ + Sub(data='s1data', sub='s1sub', concat='s1data|s1sub'), + Sub(data='s2data', sub='s2sub', concat='s2data|s2sub'), + Sub(data='s3data', sub='s3sub', concat='s3data|s3sub') + ] + ) + + class PKDiscriminatorTest(_base.MappedTest): @classmethod def define_tables(cls, metadata):