"""
props = self._props
- tables = set(props[key].parent.local_table for key in attribute_names)
+
+ tables = set(chain(*
+ (sqlutil.find_tables(props[key].columns[0], check_columns=True)
+ for key in attribute_names)
+ ))
+
if self.base_mapper.local_table in tables:
return None
return None
cond = sql.and_(*allconds)
- return sql.select(tables, cond, use_labels=True)
+
+ return sql.select([props[key].columns[0] for key in attribute_names], cond, use_labels=True)
def cascade_iterator(self, type_, state, halt_on=None):
"""Iterate each element and its mapper in an object graph,
tables = []
_visitors = {}
- def visit_something(elem):
- tables.append(elem)
-
if include_selects:
- _visitors['select'] = _visitors['compound_select'] = visit_something
+ _visitors['select'] = _visitors['compound_select'] = tables.append
if include_joins:
- _visitors['join'] = visit_something
+ _visitors['join'] = tables.append
if include_aliases:
- _visitors['alias'] = visit_something
+ _visitors['alias'] = tables.append
if check_columns:
def visit_column(column):
tables.append(column.table)
_visitors['column'] = visit_column
- _visitors['table'] = visit_something
+ _visitors['table'] = tables.append
visitors.traverse(clause, {'column_collections':False}, _visitors)
return tables
assert sess.query(Sub).get(s1.base_id).data == "this is base"
class OptimizedLoadTest(_base.MappedTest):
- """test that the 'optimized load' routine doesn't crash when
- a column in the join condition is not available.
+ """tests for the "optimized load" routine."""
- """
@classmethod
def define_tables(cls, metadata):
global base, sub
)
def test_optimized_passes(self):
- class Base(object):
+ """"test that the 'optimized load' routine doesn't crash when
+ a column in the join condition is not available."""
+
+ class Base(_base.BasicEntity):
pass
class Sub(Base):
pass
# redefine Sub's "id" to favor the "id" col in the subtable.
# "id" is also part of the primary join condition
mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={'id':sub.c.id})
- sess = create_session()
- s1 = Sub()
- s1.data = 's1data'
- s1.sub = 's1sub'
+ sess = sessionmaker()()
+ s1 = Sub(data='s1data', sub='s1sub')
sess.add(s1)
- sess.flush()
+ sess.commit()
sess.expunge_all()
# load s1 via Base. s1.id won't populate since it's relative to
# the "sub" table. The optimized load kicks in and tries to
# generate on the primary join, but cannot since "id" is itself unloaded.
# the optimized load needs to return "None" so regular full-row loading proceeds
- s1 = sess.query(Base).get(s1.id)
+ s1 = sess.query(Base).first()
assert s1.sub == 's1sub'
+ def test_column_expression(self):
+ class Base(_base.BasicEntity):
+ pass
+ class Sub(Base):
+ pass
+ mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base')
+ mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={
+ 'concat':column_property(sub.c.sub + "|" + sub.c.sub)
+ })
+ sess = sessionmaker()()
+ s1 = Sub(data='s1data', sub='s1sub')
+ sess.add(s1)
+ sess.commit()
+ sess.expunge_all()
+ s1 = sess.query(Base).first()
+ assert s1.concat == 's1sub|s1sub'
+
+ def test_column_expression_joined(self):
+ class Base(_base.ComparableEntity):
+ pass
+ class Sub(Base):
+ pass
+ mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base')
+ mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={
+ 'concat':column_property(base.c.data + "|" + sub.c.sub)
+ })
+ sess = sessionmaker()()
+ s1 = Sub(data='s1data', sub='s1sub')
+ s2 = Sub(data='s2data', sub='s2sub')
+ s3 = Sub(data='s3data', sub='s3sub')
+ sess.add_all([s1, s2, s3])
+ sess.commit()
+ sess.expunge_all()
+ # query a bunch of rows to ensure there's no cartesian
+ # product against "base" occurring, it is in fact
+ # detecting that "base" needs to be in the join
+ # criterion
+ eq_(
+ sess.query(Base).order_by(Base.id).all(),
+ [
+ Sub(data='s1data', sub='s1sub', concat='s1data|s1sub'),
+ Sub(data='s2data', sub='s2sub', concat='s2data|s2sub'),
+ Sub(data='s3data', sub='s3sub', concat='s3data|s3sub')
+ ]
+ )
+
+
class PKDiscriminatorTest(_base.MappedTest):
@classmethod
def define_tables(cls, metadata):