# right set
if hasattr(self, '_cols_by_table') and col.table in self._cols_by_table and col not in self._cols_by_table[col.table]:
self._cols_by_table[col.table].add(col)
-
+
+ # if this ColumnProperty represents the "polymorphic discriminator"
+ # column, mark it. We'll need this when rendering columns
+ # in SELECT statements.
+ if not hasattr(prop, '_is_polymorphic_discriminator'):
+ prop._is_polymorphic_discriminator = (col is self.polymorphic_on or prop.columns[0] is self.polymorphic_on)
+
self.columns[key] = col
for col in prop.columns:
for col in col.proxy_set:
else:
return mappers, self._selectable_from_mappers(mappers)
- @property
- def _default_polymorphic_properties(self):
- return util.unique_list(
- chain(*[list(mapper.iterate_properties) for mapper in [self] + self._with_polymorphic_mappers])
- )
-
def _iterate_polymorphic_properties(self, mappers=None):
+ """Return an iterator of MapperProperty objects which will render into a SELECT."""
+
if mappers is None:
- return iter(self._default_polymorphic_properties)
+ mappers = self._with_polymorphic_mappers
+
+ if not mappers:
+ for c in self.iterate_properties:
+ yield c
else:
- return iter(util.unique_list(
+ # in the polymorphic case, filter out discriminator columns
+ # from other mappers, as these are sometimes dependent on that
+ # mapper's polymorphic selectable (which we don't want rendered)
+ for c in util.unique_list(
chain(*[list(mapper.iterate_properties) for mapper in [self] + mappers])
- ))
-
+ ):
+ if getattr(c, '_is_polymorphic_discriminator', False) and \
+ (not self.polymorphic_on or c.columns[0] is not self.polymorphic_on):
+ continue
+ yield c
+
@property
def properties(self):
raise NotImplementedError("Public collection of MapperProperty objects is "
class ConcreteTest(ORMTest):
def define_tables(self, metadata):
- global managers_table, engineers_table, hackers_table, companies
+ global managers_table, engineers_table, hackers_table, companies, employees_table
companies = Table('companies', metadata,
Column('id', Integer, primary_key=True),
Column('name', String(50)))
+ employees_table = Table('employees', metadata,
+ Column('employee_id', Integer, primary_key=True),
+ Column('name', String(50)),
+ Column('company_id', Integer, ForeignKey('companies.id'))
+ )
+
managers_table = Table('managers', metadata,
Column('employee_id', Integer, primary_key=True),
Column('name', String(50)),
session.expire(manager, ['manager_data'])
self.assertEquals(manager.manager_data, "knows how to manage things")
- def test_multi_level(self):
+ def test_multi_level_no_base(self):
class Employee(object):
def __init__(self, name):
self.name = name
assert set([repr(x) for x in session.query(Employee).all()]) == set(["Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"])
assert set([repr(x) for x in session.query(Manager).all()]) == set(["Manager Tom knows how to manage things"])
assert set([repr(x) for x in session.query(Engineer).all()]) == set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"])
+ assert set([repr(x) for x in session.query(Hacker).all()]) == set(["Hacker Kurt 'Badass' knows how to hack"])
+
+ def test_multi_level_with_base(self):
+ class Employee(object):
+ def __init__(self, name):
+ self.name = name
+ def __repr__(self):
+ return self.__class__.__name__ + " " + self.name
+
+ class Manager(Employee):
+ def __init__(self, name, manager_data):
+ self.name = name
+ self.manager_data = manager_data
+ def __repr__(self):
+ return self.__class__.__name__ + " " + self.name + " " + self.manager_data
+
+ class Engineer(Employee):
+ def __init__(self, name, engineer_info):
+ self.name = name
+ self.engineer_info = engineer_info
+ def __repr__(self):
+ return self.__class__.__name__ + " " + self.name + " " + self.engineer_info
+
+ class Hacker(Engineer):
+ def __init__(self, name, nickname, engineer_info):
+ self.name = name
+ self.nickname = nickname
+ self.engineer_info = engineer_info
+ def __repr__(self):
+ return self.__class__.__name__ + " " + self.name + " '" + \
+ self.nickname + "' " + self.engineer_info
+
+ pjoin = polymorphic_union({
+ 'employee':employees_table,
+ 'manager': managers_table,
+ 'engineer': engineers_table,
+ 'hacker': hackers_table
+ }, 'type', 'pjoin')
+
+ pjoin2 = polymorphic_union({
+ 'engineer': engineers_table,
+ 'hacker': hackers_table
+ }, 'type', 'pjoin2')
+
+ employee_mapper = mapper(Employee, employees_table, with_polymorphic=('*', pjoin), polymorphic_on=pjoin.c.type)
+ manager_mapper = mapper(Manager, managers_table,
+ inherits=employee_mapper, concrete=True,
+ polymorphic_identity='manager')
+ engineer_mapper = mapper(Engineer, engineers_table,
+ with_polymorphic=('*', pjoin2),
+ polymorphic_on=pjoin2.c.type,
+ inherits=employee_mapper, concrete=True,
+ polymorphic_identity='engineer')
+ hacker_mapper = mapper(Hacker, hackers_table,
+ inherits=engineer_mapper,
+ concrete=True, polymorphic_identity='hacker')
+
+ session = create_session()
+ tom = Manager('Tom', 'knows how to manage things')
+ jerry = Engineer('Jerry', 'knows how to program')
+ hacker = Hacker('Kurt', 'Badass', 'knows how to hack')
+ session.add_all((tom, jerry, hacker))
+ session.flush()
+
+ # ensure "readonly" on save logic didn't pollute the expired_attributes
+ # collection
+ assert 'nickname' not in attributes.instance_state(jerry).expired_attributes
+ assert 'name' not in attributes.instance_state(jerry).expired_attributes
+ assert 'name' not in attributes.instance_state(hacker).expired_attributes
+ assert 'nickname' not in attributes.instance_state(hacker).expired_attributes
+ def go():
+ self.assertEquals(jerry.name, "Jerry")
+ self.assertEquals(hacker.nickname, "Badass")
+ self.assert_sql_count(testing.db, go, 0)
+
+ session.clear()
+
+ # check that we aren't getting a cartesian product in the raw SQL.
+ # this requires that Engineer's polymorphic discriminator is not rendered
+ # in the statement which is only against Employee's "pjoin"
+ assert len(testing.db.execute(session.query(Employee).with_labels().statement).fetchall()) == 3
+
+ assert set([repr(x) for x in session.query(Employee).all()]) == set(["Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"])
+ assert set([repr(x) for x in session.query(Manager).all()]) == set(["Manager Tom knows how to manage things"])
+ assert set([repr(x) for x in session.query(Engineer).all()]) == set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"])
+ assert set([repr(x) for x in session.query(Hacker).all()]) == set(["Hacker Kurt 'Badass' knows how to hack"])
def test_relation(self):
class Employee(object):
assert str(e) == "a_1.id = a.xxx_id"
+ def test_recursive_equivalents(self):
+ m = MetaData()
+ a = Table('a', m, Column('x', Integer), Column('y', Integer))
+ b = Table('b', m, Column('x', Integer), Column('y', Integer))
+ c = Table('c', m, Column('x', Integer), Column('y', Integer))
+
+ # force a recursion overflow, by linking a.c.x<->c.c.x, and
+ # asking for a nonexistent col. corresponding_column should prevent
+ # endless depth.
+ adapt = sql_util.ClauseAdapter( b, equivalents= {a.c.x: set([ c.c.x]), c.c.x:set([a.c.x])})
+ assert adapt._corresponding_column(a.c.x, False) is None
+
+ def test_multilevel_equivalents(self):
+ m = MetaData()
+ a = Table('a', m, Column('x', Integer), Column('y', Integer))
+ b = Table('b', m, Column('x', Integer), Column('y', Integer))
+ c = Table('c', m, Column('x', Integer), Column('y', Integer))
+
+ alias = select([a]).select_from(a.join(b, a.c.x==b.c.x)).alias()
+
+ # two levels of indirection from c.x->b.x->a.x, requires recursive
+ # corresponding_column call
+ adapt = sql_util.ClauseAdapter(alias, equivalents= {b.c.x: set([ a.c.x]), c.c.x:set([b.c.x])})
+ assert adapt._corresponding_column(a.c.x, False) is alias.c.x
+ assert adapt._corresponding_column(c.c.x, False) is alias.c.x
+
def test_join_to_alias(self):
metadata = MetaData()
a = Table('a', metadata,