def visit_binary(binary):
leftcol = binary.left
rightcol = binary.right
+
if leftcol is None or rightcol is None:
return
if leftcol.table not in needs_tables:
binary.left = sql.bindparam(None, None, type_=binary.right.type)
param_names.append((leftcol, binary.left))
- elif rightcol not in needs_tables:
+ elif rightcol.table not in needs_tables:
binary.right = sql.bindparam(None, None, type_=binary.right.type)
param_names.append((rightcol, binary.right))
break
if not mapper.single:
allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary))
-
+
return sql.and_(*allconds), param_names
Mapper.logger = logging.class_logger(Mapper)
test_get_polymorphic = create_test(True, 'test_get_polymorphic')
test_get_nonpolymorphic = create_test(False, 'test_get_nonpolymorphic')
+class InheritConditionTest(ORMTest):
+ def define_tables(self, metadata):
+ global base, child
+ base = Table('base', metadata,
+ Column('id', Integer, primary_key = True),
+ Column('type', String(40)))
+
+ child = Table('child', metadata,
+ Column('base_id', Integer, ForeignKey(base.c.id), primary_key = True),
+ Column('title', String(64)),
+ Column('parent_id', Integer, ForeignKey('child.base_id'))
+ )
+
+ def test_inherit_cond(self):
+ class Base(fixtures.Base):
+ pass
+ class Child(Base):
+ pass
+
+ mapper(Base, base, polymorphic_on=base.c.type)
+ mapper(Child, child, inherits=Base, inherit_condition=child.c.base_id==base.c.id, polymorphic_identity='c')
+
+ sess = create_session()
+ sess.save(Child(title='c1'))
+ sess.save(Child(title='c2'))
+ sess.flush()
+ sess.clear()
+ for inherit_cond in (child.c.base_id==base.c.id, base.c.id==child.c.base_id):
+ clear_mappers()
+ mapper(Base, base, polymorphic_on=base.c.type)
+ mapper(Child, child, inherits=Base, inherit_condition=inherit_cond, polymorphic_identity='c')
+
+ assert sess.query(Base).order_by(Base.id).all() == [Child(title='c1'), Child(title='c2')]
+
+
class ConstructionTest(ORMTest):
def define_tables(self, metadata):
global content_type, content, product