self._current_path = path
@_generative(__no_clauseelement_condition)
- def with_polymorphic(self, cls_or_mappers, selectable=None):
+ def with_polymorphic(self, cls_or_mappers, selectable=None, discriminator=None):
"""Load columns for descendant mappers of this Query's mapper.
Using this method will ensure that each descendant mapper's
instances will also have those columns already loaded so that
no "post fetch" of those columns will be required.
- ``cls_or_mappers`` is a single class or mapper, or list of class/mappers,
+ :param cls_or_mappers: - a single class or mapper, or list of class/mappers,
which inherit from this Query's mapper. Alternatively, it
may also be the string ``'*'``, in which case all descending
mappers will be added to the FROM clause.
- ``selectable`` is a table or select() statement that will
+ :param selectable: - a table or select() statement that will
be used in place of the generated FROM clause. This argument
is required if any of the desired mappers use concrete table
inheritance, since SQLAlchemy currently cannot generate UNIONs
will result in their table being appended directly to the FROM
clause which will usually lead to incorrect results.
+ :param discriminator: - a column to be used as the "discriminator"
+ column for the given selectable. If not given, the polymorphic_on
+ attribute of the mapper will be used, if any. This is useful
+ for mappers that don't have polymorphic loading behavior by default,
+ such as concrete table mappers.
+
"""
entity = self._generate_mapper_zero()
- entity.set_with_polymorphic(self, cls_or_mappers, selectable=selectable)
+ entity.set_with_polymorphic(self, cls_or_mappers, selectable=selectable, discriminator=discriminator)
@_generative()
def yield_per(self, count):
self.adapter = adapter
self.selectable = from_obj
self._with_polymorphic = with_polymorphic
+ self._polymorphic_discriminator = None
self.is_aliased_class = is_aliased_class
if is_aliased_class:
self.path_entity = self.entity = self.entity_zero = entity
self.path_entity = mapper.base_mapper
self.entity = self.entity_zero = mapper
- def set_with_polymorphic(self, query, cls_or_mappers, selectable):
+ def set_with_polymorphic(self, query, cls_or_mappers, selectable, discriminator):
if cls_or_mappers is None:
query._reset_polymorphic_adapter(self.mapper)
return
mappers, from_obj = self.mapper._with_polymorphic_args(cls_or_mappers, selectable)
self._with_polymorphic = mappers
+ self._polymorphic_discriminator = discriminator
# TODO: do the wrapped thing here too so that with_polymorphic() can be
# applied to aliases
if self.primary_entity:
_instance = self.mapper._instance_processor(context, (self.path_entity,), adapter,
- extension=self.extension, only_load_props=query._only_load_props, refresh_state=context.refresh_state
+ extension=self.extension, only_load_props=query._only_load_props, refresh_state=context.refresh_state,
+ polymorphic_discriminator=self._polymorphic_discriminator
)
else:
- _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter)
+ _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter,
+ polymorphic_discriminator=self._polymorphic_discriminator)
if custom_rows:
def main(context, row, result):
only_load_props=query._only_load_props,
column_collection=context.primary_columns
)
-
+
+ if self._polymorphic_discriminator:
+ if adapter:
+ pd = adapter.columns[self._polymorphic_discriminator]
+ else:
+ pd = self._polymorphic_discriminator
+ context.primary_columns.append(pd)
+
def __str__(self):
return str(self.mapper)
test_get_polymorphic = create_test(True, 'test_get_polymorphic')
test_get_nonpolymorphic = create_test(False, 'test_get_nonpolymorphic')
-class ConstructionTest(ORMTest):
- def define_tables(self, metadata):
- global content_type, content, product
- content_type = Table('content_type', metadata,
- Column('id', Integer, primary_key=True)
- )
- content = Table('content', metadata,
- Column('id', Integer, primary_key=True),
- Column('content_type_id', Integer, ForeignKey('content_type.id')),
- Column('type', String(30))
- )
- product = Table('product', metadata,
- Column('id', Integer, ForeignKey('content.id'), primary_key=True)
- )
-
- def testbasic(self):
- class ContentType(object): pass
- class Content(object): pass
- class Product(Content): pass
-
- content_types = mapper(ContentType, content_type)
- try:
- contents = mapper(Content, content, properties={
- 'content_type':relation(content_types)
- }, polymorphic_identity='contents')
- assert False
- except sa_exc.ArgumentError, e:
- assert str(e) == "Mapper 'Mapper|Content|content' specifies a polymorphic_identity of 'contents', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument"
-
- def testbackref(self):
- """tests adding a property to the superclass mapper"""
- class ContentType(object): pass
- class Content(object): pass
- class Product(Content): pass
-
- contents = mapper(Content, content, polymorphic_on=content.c.type, polymorphic_identity='content')
- products = mapper(Product, product, inherits=contents, polymorphic_identity='product')
- content_types = mapper(ContentType, content_type, properties={
- 'content':relation(contents, backref='contenttype')
- })
- p = Product()
- p.contenttype = ContentType()
- # TODO: assertion ??
-
class EagerLazyTest(ORMTest):
"""tests eager load/lazy load of child items off inheritance mappers, tests that
LazyLoader constructs the right query condition."""
from sqlalchemy.orm import exc as orm_exc
from testlib import *
from sqlalchemy.orm import attributes
+from testlib.testing import eq_
+
+class Employee(object):
+ def __init__(self, name):
+ self.name = name
+ def __repr__(self):
+ return self.__class__.__name__ + " " + self.name
+
+class Manager(Employee):
+ def __init__(self, name, manager_data):
+ self.name = name
+ self.manager_data = manager_data
+ def __repr__(self):
+ return self.__class__.__name__ + " " + self.name + " " + self.manager_data
+
+class Engineer(Employee):
+ def __init__(self, name, engineer_info):
+ self.name = name
+ self.engineer_info = engineer_info
+ def __repr__(self):
+ return self.__class__.__name__ + " " + self.name + " " + self.engineer_info
+
+class Hacker(Engineer):
+ def __init__(self, name, nickname, engineer_info):
+ self.name = name
+ self.nickname = nickname
+ self.engineer_info = engineer_info
+ def __repr__(self):
+ return self.__class__.__name__ + " " + self.name + " '" + \
+ self.nickname + "' " + self.engineer_info
+
+class Company(object):
+ pass
+
class ConcreteTest(ORMTest):
def define_tables(self, metadata):
)
def test_basic(self):
- class Employee(object):
- def __init__(self, name):
- self.name = name
- def __repr__(self):
- return self.__class__.__name__ + " " + self.name
-
- class Manager(Employee):
- def __init__(self, name, manager_data):
- self.name = name
- self.manager_data = manager_data
- def __repr__(self):
- return self.__class__.__name__ + " " + self.name + " " + self.manager_data
-
- class Engineer(Employee):
- def __init__(self, name, engineer_info):
- self.name = name
- self.engineer_info = engineer_info
- def __repr__(self):
- return self.__class__.__name__ + " " + self.name + " " + self.engineer_info
-
pjoin = polymorphic_union({
'manager':managers_table,
'engineer':engineers_table
session.flush()
session.clear()
- print set([repr(x) for x in session.query(Employee).all()])
- assert set([repr(x) for x in session.query(Employee).all()]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"])
- assert set([repr(x) for x in session.query(Manager).all()]) == set(["Manager Tom knows how to manage things"])
- assert set([repr(x) for x in session.query(Engineer).all()]) == set(["Engineer Kurt knows how to hack"])
+ assert set([repr(x) for x in session.query(Employee)]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"])
+ assert set([repr(x) for x in session.query(Manager)]) == set(["Manager Tom knows how to manage things"])
+ assert set([repr(x) for x in session.query(Engineer)]) == set(["Engineer Kurt knows how to hack"])
manager = session.query(Manager).one()
session.expire(manager, ['manager_data'])
self.assertEquals(manager.manager_data, "knows how to manage things")
def test_multi_level_no_base(self):
- class Employee(object):
- def __init__(self, name):
- self.name = name
- def __repr__(self):
- return self.__class__.__name__ + " " + self.name
-
- class Manager(Employee):
- def __init__(self, name, manager_data):
- self.name = name
- self.manager_data = manager_data
- def __repr__(self):
- return self.__class__.__name__ + " " + self.name + " " + self.manager_data
-
- class Engineer(Employee):
- def __init__(self, name, engineer_info):
- self.name = name
- self.engineer_info = engineer_info
- def __repr__(self):
- return self.__class__.__name__ + " " + self.name + " " + self.engineer_info
-
- class Hacker(Engineer):
- def __init__(self, name, nickname, engineer_info):
- self.name = name
- self.nickname = nickname
- self.engineer_info = engineer_info
- def __repr__(self):
- return self.__class__.__name__ + " " + self.name + " '" + \
- self.nickname + "' " + self.engineer_info
-
pjoin = polymorphic_union({
'manager': managers_table,
'engineer': engineers_table,
assert set([repr(x) for x in session.query(Hacker).all()]) == set(["Hacker Kurt 'Badass' knows how to hack"])
def test_multi_level_with_base(self):
- class Employee(object):
- def __init__(self, name):
- self.name = name
- def __repr__(self):
- return self.__class__.__name__ + " " + self.name
-
- class Manager(Employee):
- def __init__(self, name, manager_data):
- self.name = name
- self.manager_data = manager_data
- def __repr__(self):
- return self.__class__.__name__ + " " + self.name + " " + self.manager_data
-
- class Engineer(Employee):
- def __init__(self, name, engineer_info):
- self.name = name
- self.engineer_info = engineer_info
- def __repr__(self):
- return self.__class__.__name__ + " " + self.name + " " + self.engineer_info
-
- class Hacker(Engineer):
- def __init__(self, name, nickname, engineer_info):
- self.name = name
- self.nickname = nickname
- self.engineer_info = engineer_info
- def __repr__(self):
- return self.__class__.__name__ + " " + self.name + " '" + \
- self.nickname + "' " + self.engineer_info
-
pjoin = polymorphic_union({
'employee':employees_table,
'manager': managers_table,
session.add_all((tom, jerry, hacker))
session.flush()
- # ensure "readonly" on save logic didn't pollute the expired_attributes
- # collection
- assert 'nickname' not in attributes.instance_state(jerry).expired_attributes
- assert 'name' not in attributes.instance_state(jerry).expired_attributes
- assert 'name' not in attributes.instance_state(hacker).expired_attributes
- assert 'nickname' not in attributes.instance_state(hacker).expired_attributes
def go():
self.assertEquals(jerry.name, "Jerry")
self.assertEquals(hacker.nickname, "Badass")
# in the statement which is only against Employee's "pjoin"
assert len(testing.db.execute(session.query(Employee).with_labels().statement).fetchall()) == 3
- assert set([repr(x) for x in session.query(Employee).all()]) == set(["Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"])
- assert set([repr(x) for x in session.query(Manager).all()]) == set(["Manager Tom knows how to manage things"])
- assert set([repr(x) for x in session.query(Engineer).all()]) == set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"])
- assert set([repr(x) for x in session.query(Hacker).all()]) == set(["Hacker Kurt 'Badass' knows how to hack"])
+ assert set([repr(x) for x in session.query(Employee)]) == set(["Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"])
+ assert set([repr(x) for x in session.query(Manager)]) == set(["Manager Tom knows how to manage things"])
+ assert set([repr(x) for x in session.query(Engineer)]) == set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"])
+ assert set([repr(x) for x in session.query(Hacker)]) == set(["Hacker Kurt 'Badass' knows how to hack"])
- def test_relation(self):
- class Employee(object):
- def __init__(self, name):
- self.name = name
- def __repr__(self):
- return self.__class__.__name__ + " " + self.name
-
- class Manager(Employee):
- def __init__(self, name, manager_data):
- self.name = name
- self.manager_data = manager_data
- def __repr__(self):
- return self.__class__.__name__ + " " + self.name + " " + self.manager_data
-
- class Engineer(Employee):
- def __init__(self, name, engineer_info):
- self.name = name
- self.engineer_info = engineer_info
- def __repr__(self):
- return self.__class__.__name__ + " " + self.name + " " + self.engineer_info
-
- class Company(object):
- pass
+
+ def test_without_default_polymorphic(self):
+ pjoin = polymorphic_union({
+ 'employee':employees_table,
+ 'manager': managers_table,
+ 'engineer': engineers_table,
+ 'hacker': hackers_table
+ }, 'type', 'pjoin')
+
+ pjoin2 = polymorphic_union({
+ 'engineer': engineers_table,
+ 'hacker': hackers_table
+ }, 'type', 'pjoin2')
+ employee_mapper = mapper(Employee, employees_table,
+ polymorphic_identity='employee')
+ manager_mapper = mapper(Manager, managers_table,
+ inherits=employee_mapper, concrete=True,
+ polymorphic_identity='manager')
+ engineer_mapper = mapper(Engineer, engineers_table,
+ inherits=employee_mapper, concrete=True,
+ polymorphic_identity='engineer')
+ hacker_mapper = mapper(Hacker, hackers_table,
+ inherits=engineer_mapper,
+ concrete=True, polymorphic_identity='hacker')
+
+ session = create_session()
+ jdoe = Employee('Jdoe')
+ tom = Manager('Tom', 'knows how to manage things')
+ jerry = Engineer('Jerry', 'knows how to program')
+ hacker = Hacker('Kurt', 'Badass', 'knows how to hack')
+ session.add_all((jdoe, tom, jerry, hacker))
+ session.flush()
+
+ eq_(
+ len(testing.db.execute(session.query(Employee).with_polymorphic('*', pjoin, pjoin.c.type).with_labels().statement).fetchall()),
+ 4
+ )
+
+ eq_(
+ session.query(Employee).get(jdoe.employee_id), jdoe
+ )
+ eq_(
+ session.query(Engineer).get(jerry.employee_id), jerry
+ )
+ eq_(
+ set([repr(x) for x in session.query(Employee).with_polymorphic('*', pjoin, pjoin.c.type)]),
+ set(["Employee Jdoe", "Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"])
+ )
+ eq_(
+ set([repr(x) for x in session.query(Manager)]),
+ set(["Manager Tom knows how to manage things"])
+ )
+ eq_(
+ set([repr(x) for x in session.query(Engineer).with_polymorphic('*', pjoin2, pjoin2.c.type)]),
+ set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"])
+ )
+ eq_(
+ set([repr(x) for x in session.query(Hacker)]),
+ set(["Hacker Kurt 'Badass' knows how to hack"])
+ )
+ # test adaption of the column by wrapping the query in a subquery
+ eq_(
+ len(testing.db.execute(
+ session.query(Engineer).with_polymorphic('*', pjoin2, pjoin2.c.type).from_self().statement
+ ).fetchall()),
+ 2
+ )
+ eq_(
+ set([repr(x) for x in session.query(Engineer).with_polymorphic('*', pjoin2, pjoin2.c.type).from_self()]),
+ set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"])
+ )
+
+ def test_relation(self):
pjoin = polymorphic_union({
'manager':managers_table,
'engineer':engineers_table
concrete=True, polymorphic_identity='refugee')
sess = create_session()
- assert sess.query(Refugee).get(1).name == "refugee1"
- assert sess.query(Refugee).get(2).name == "refugee2"
+ eq_(sess.query(Refugee).get(1).name, "refugee1")
+ eq_(sess.query(Refugee).get(2).name, "refugee2")
- assert sess.query(Office).get(1).name == "office1"
- assert sess.query(Office).get(2).name == "office2"
+ eq_(sess.query(Office).get(1).name, "office1")
+ eq_(sess.query(Office).get(2).name, "office2")
if __name__ == '__main__':
testenv.main()