to_selectable = target_mapper._with_polymorphic_selectable
if self.prop._is_self_referential():
to_selectable = to_selectable.alias()
+
+ single_crit = target_mapper._single_table_criterion
+ if single_crit:
+ if criterion is not None:
+ criterion = single_crit & criterion
+ else:
+ criterion = single_crit
else:
to_selectable = None
source_selectable = self.__clause_element__()
else:
source_selectable = None
+
pj, sj, source, dest, secondary, target_adapter = \
self.prop._create_joins(dest_polymorphic=True, dest_selectable=to_selectable, source_selectable=source_selectable)
def _is_self_referential(self):
return self.mapper.common_parent(self.parent)
- def _create_joins(self, source_polymorphic=False, source_selectable=None, dest_polymorphic=False, dest_selectable=None):
- key = util.WeakCompositeKey(source_polymorphic, source_selectable, dest_polymorphic, dest_selectable)
+ def _create_joins(self, source_polymorphic=False, source_selectable=None, dest_polymorphic=False, dest_selectable=None, of_type=None):
+ key = util.WeakCompositeKey(source_polymorphic, source_selectable, dest_polymorphic, dest_selectable, of_type)
try:
return self.__join_cache[key]
except KeyError:
# in the case that the join is to a subclass
# this is analgous to the "_adjust_for_single_table_inheritance()"
# method in Query.
- if self.mapper.single and self.mapper.inherits and self.mapper.polymorphic_on and self.mapper.polymorphic_identity is not None:
- crit = self.mapper.polymorphic_on.in_(
- m.polymorphic_identity
- for m in self.mapper.polymorphic_iterator())
+
+ dest_mapper = of_type or self.mapper
+
+ single_crit = dest_mapper._single_table_criterion
+ if single_crit:
if secondaryjoin:
- secondaryjoin = secondaryjoin & crit
+ secondaryjoin = secondaryjoin & single_crit
else:
- primaryjoin = primaryjoin & crit
+ primaryjoin = primaryjoin & single_crit
if aliased:
self.__currenttables.add(prop.secondary)
self.__currenttables.add(prop.table)
- right_entity = prop.mapper
+ if not right_entity:
+ right_entity = prop.mapper
if alias_criterion:
right_adapter = ORMAdapter(right_entity,
"""
for entity, (mapper, adapter, s, i, w) in self._mapper_adapter_map.iteritems():
- if mapper.single and mapper.inherits and mapper.polymorphic_on and mapper.polymorphic_identity is not None:
- crit = mapper.polymorphic_on.in_(
- m.polymorphic_identity
- for m in mapper.polymorphic_iterator())
+ single_crit = mapper._single_table_criterion
+ if single_crit:
if adapter:
- crit = adapter.traverse(crit)
- crit = self._adapt_clause(crit, False, False)
- context.whereclause = sql.and_(context.whereclause, crit)
+ single_crit = adapter.traverse(single_crit)
+ single_crit = self._adapt_clause(single_crit, False, False)
+ context.whereclause = sql.and_(context.whereclause, single_crit)
def __str__(self):
return str(self._compile_context().statement)
from testlib.fixtures import Base
from orm._base import MappedTest, ComparableEntity
+
class SingleInheritanceTest(MappedTest):
def define_tables(self, metadata):
- global employees_table
- employees_table = Table('employees', metadata,
+ Table('employees', metadata,
Column('employee_id', Integer, primary_key=True),
Column('name', String(50)),
Column('manager_data', String(50)),
Column('engineer_info', String(50)),
- Column('type', String(20))
+ Column('type', String(20)))
+
+ Table('reports', metadata,
+ Column('report_id', Integer, primary_key=True),
+ Column('employee_id', ForeignKey('employees.employee_id')),
+ Column('name', String(50)),
)
-
+
def setup_classes(self):
class Employee(ComparableEntity):
pass
@testing.resolve_artifact_names
def setup_mappers(self):
- mapper(Employee, employees_table, polymorphic_on=employees_table.c.type)
+ mapper(Employee, employees, polymorphic_on=employees.c.type)
mapper(Manager, inherits=Employee, polymorphic_identity='manager')
mapper(Engineer, inherits=Employee, polymorphic_identity='engineer')
mapper(JuniorEngineer, inherits=Engineer, polymorphic_identity='juniorengineer')
sess.flush()
self.assertEquals(
- sess.query(Manager).select_from(employees_table.select().limit(10)).all(),
+ sess.query(Manager).select_from(employees.select().limit(10)).all(),
[m1, m2]
)
self.assertEquals(sess.query(Manager).filter(Manager.name.like('%m%')).count(), 2)
self.assertEquals(sess.query(Employee).filter(Employee.name.like('%m%')).count(), 3)
+ @testing.resolve_artifact_names
+ def test_type_filtering(self):
+ class Report(ComparableEntity): pass
+
+ mapper(Report, reports, properties={
+ 'employee': relation(Employee, backref='reports')})
+ sess = create_session()
+
+ m1 = Manager(name='Tom', manager_data='data1')
+ r1 = Report(employee=m1)
+ sess.add_all([m1, r1])
+ sess.flush()
+ rq = sess.query(Report)
+
+ assert len(rq.filter(Report.employee.of_type(Manager).has()).all()) == 1
+ assert len(rq.filter(Report.employee.of_type(Engineer).has()).all()) == 0
+
+ @testing.resolve_artifact_names
+ def test_type_joins(self):
+ class Report(ComparableEntity): pass
+
+ mapper(Report, reports, properties={
+ 'employee': relation(Employee, backref='reports')})
+ sess = create_session()
+
+ m1 = Manager(name='Tom', manager_data='data1')
+ r1 = Report(employee=m1)
+ sess.add_all([m1, r1])
+ sess.flush()
+
+ rq = sess.query(Report)
+
+ assert len(rq.join(Report.employee.of_type(Manager)).all()) == 1
+ assert len(rq.join(Report.employee.of_type(Engineer)).all()) == 0
+
+
class RelationToSingleTest(MappedTest):
def define_tables(self, metadata):
Table('employees', metadata,
class JuniorEngineer(Engineer):
pass
+ @testing.resolve_artifact_names
+ def test_of_type(self):
+ mapper(Company, companies, properties={
+ 'employees':relation(Employee, backref='company')
+ })
+ mapper(Employee, employees, polymorphic_on=employees.c.type)
+ mapper(Manager, inherits=Employee, polymorphic_identity='manager')
+ mapper(Engineer, inherits=Employee, polymorphic_identity='engineer')
+ mapper(JuniorEngineer, inherits=Engineer, polymorphic_identity='juniorengineer')
+ sess = sessionmaker()()
+
+ c1 = Company(name='c1')
+ c2 = Company(name='c2')
+
+ m1 = Manager(name='Tom', manager_data='data1', company=c1)
+ m2 = Manager(name='Tom2', manager_data='data2', company=c2)
+ e1 = Engineer(name='Kurt', engineer_info='knows how to hack', company=c2)
+ e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed', company=c1)
+ sess.add_all([c1, c2, m1, m2, e1, e2])
+ sess.commit()
+ sess.clear()
+ self.assertEquals(
+ sess.query(Company).filter(Company.employees.of_type(JuniorEngineer).any()).all(),
+ [
+ Company(name='c1'),
+ ]
+ )
+
+ self.assertEquals(
+ sess.query(Company).join(Company.employees.of_type(JuniorEngineer)).all(),
+ [
+ Company(name='c1'),
+ ]
+ )
+
+
@testing.resolve_artifact_names
def test_relation_to_subclass(self):
mapper(Company, companies, properties={