From fff9409a339d7a8d33667f5f717a3ae7ed334842 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 3 Dec 2008 21:27:04 +0000 Subject: [PATCH] - Query.with_polymorphic() now accepts a third argument "discriminator" which will replace the value of mapper.polymorphic_on for that query. Mappers themselves no longer require polymorphic_on to be set, even if the mapper has a polymorphic_identity. When not set, the mapper will load non-polymorphically by default. Together, these two features allow a non-polymorphic concrete inheritance setup to use polymorphic loading on a per-query basis, since concrete setups are prone to many issues when used polymorphically in all cases. --- CHANGES | 15 ++ lib/sqlalchemy/orm/mapper.py | 15 +- lib/sqlalchemy/orm/query.py | 33 +++-- test/orm/inheritance/basic.py | 44 ------ test/orm/inheritance/concrete.py | 237 +++++++++++++++---------------- 5 files changed, 161 insertions(+), 183 deletions(-) diff --git a/CHANGES b/CHANGES index f1245c7bfa..7ae9e7b60c 100644 --- a/CHANGES +++ b/CHANGES @@ -6,6 +6,21 @@ CHANGES ======= 0.5.0rc5 ======== +- new features +- orm + - Query.with_polymorphic() now accepts a third + argument "discriminator" which will replace + the value of mapper.polymorphic_on for that + query. Mappers themselves no longer require + polymorphic_on to be set, even if the mapper + has a polymorphic_identity. When not set, + the mapper will load non-polymorphically + by default. Together, these two features allow + a non-polymorphic concrete inheritance setup + to use polymorphic loading on a per-query basis, + since concrete setups are prone to many + issues when used polymorphically in all cases. + - bugfixes, behavioral changes - orm - Query.select_from(), from_statement() ensure diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 48c2f9e27f..9c62cadd9a 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -165,7 +165,7 @@ class Mapper(object): self.local_table = self.local_table.alias() if self.with_polymorphic and isinstance(self.with_polymorphic[1], expression._SelectBaseMixin): - self.with_polymorphic[1] = self.with_polymorphic[1].alias() + self.with_polymorphic = (self.with_polymorphic[0], self.with_polymorphic[1].alias()) # our 'polymorphic identity', a string name that when located in a result set row # indicates this Mapper should be used to construct the object instance for that row. @@ -270,20 +270,11 @@ class Mapper(object): if mapper.polymorphic_on: self.polymorphic_on = self.mapped_table.corresponding_column(mapper.polymorphic_on) break - else: - # TODO: this exception not covered - raise sa_exc.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', " - "but no mapper in it's hierarchy specifies " - "the 'polymorphic_on' column argument" % (self, self.polymorphic_identity)) else: self._all_tables = set() self.base_mapper = self self.mapped_table = self.local_table if self.polymorphic_identity: - if self.polymorphic_on is None: - raise sa_exc.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but " - "no mapper in it's hierarchy specifies the " - "'polymorphic_on' column argument" % (self, self.polymorphic_identity)) self.polymorphic_map[self.polymorphic_identity] = self self._identity_class = self.class_ @@ -1489,7 +1480,7 @@ class Mapper(object): # result set conversion - def _instance_processor(self, context, path, adapter, polymorphic_from=None, extension=None, only_load_props=None, refresh_state=None): + def _instance_processor(self, context, path, adapter, polymorphic_from=None, extension=None, only_load_props=None, refresh_state=None, polymorphic_discriminator=None): """Produce a mapper level row processor callable which processes rows into mapped instances.""" pk_cols = self.primary_key @@ -1497,7 +1488,7 @@ class Mapper(object): if polymorphic_from or refresh_state: polymorphic_on = None else: - polymorphic_on = self.polymorphic_on + polymorphic_on = polymorphic_discriminator or self.polymorphic_on polymorphic_instances = util.PopulateDict(self._configure_subclass_mapper(context, path, adapter)) version_id_col = self.version_id_col diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 4bff81d679..88357f34cb 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -358,7 +358,7 @@ class Query(object): self._current_path = path @_generative(__no_clauseelement_condition) - def with_polymorphic(self, cls_or_mappers, selectable=None): + def with_polymorphic(self, cls_or_mappers, selectable=None, discriminator=None): """Load columns for descendant mappers of this Query's mapper. Using this method will ensure that each descendant mapper's @@ -367,12 +367,12 @@ class Query(object): instances will also have those columns already loaded so that no "post fetch" of those columns will be required. - ``cls_or_mappers`` is a single class or mapper, or list of class/mappers, + :param cls_or_mappers: - a single class or mapper, or list of class/mappers, which inherit from this Query's mapper. Alternatively, it may also be the string ``'*'``, in which case all descending mappers will be added to the FROM clause. - ``selectable`` is a table or select() statement that will + :param selectable: - a table or select() statement that will be used in place of the generated FROM clause. This argument is required if any of the desired mappers use concrete table inheritance, since SQLAlchemy currently cannot generate UNIONs @@ -382,9 +382,15 @@ class Query(object): will result in their table being appended directly to the FROM clause which will usually lead to incorrect results. + :param discriminator: - a column to be used as the "discriminator" + column for the given selectable. If not given, the polymorphic_on + attribute of the mapper will be used, if any. This is useful + for mappers that don't have polymorphic loading behavior by default, + such as concrete table mappers. + """ entity = self._generate_mapper_zero() - entity.set_with_polymorphic(self, cls_or_mappers, selectable=selectable) + entity.set_with_polymorphic(self, cls_or_mappers, selectable=selectable, discriminator=discriminator) @_generative() def yield_per(self, count): @@ -1654,6 +1660,7 @@ class _MapperEntity(_QueryEntity): self.adapter = adapter self.selectable = from_obj self._with_polymorphic = with_polymorphic + self._polymorphic_discriminator = None self.is_aliased_class = is_aliased_class if is_aliased_class: self.path_entity = self.entity = self.entity_zero = entity @@ -1661,13 +1668,14 @@ class _MapperEntity(_QueryEntity): self.path_entity = mapper.base_mapper self.entity = self.entity_zero = mapper - def set_with_polymorphic(self, query, cls_or_mappers, selectable): + def set_with_polymorphic(self, query, cls_or_mappers, selectable, discriminator): if cls_or_mappers is None: query._reset_polymorphic_adapter(self.mapper) return mappers, from_obj = self.mapper._with_polymorphic_args(cls_or_mappers, selectable) self._with_polymorphic = mappers + self._polymorphic_discriminator = discriminator # TODO: do the wrapped thing here too so that with_polymorphic() can be # applied to aliases @@ -1718,10 +1726,12 @@ class _MapperEntity(_QueryEntity): if self.primary_entity: _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter, - extension=self.extension, only_load_props=query._only_load_props, refresh_state=context.refresh_state + extension=self.extension, only_load_props=query._only_load_props, refresh_state=context.refresh_state, + polymorphic_discriminator=self._polymorphic_discriminator ) else: - _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter) + _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter, + polymorphic_discriminator=self._polymorphic_discriminator) if custom_rows: def main(context, row, result): @@ -1759,7 +1769,14 @@ class _MapperEntity(_QueryEntity): only_load_props=query._only_load_props, column_collection=context.primary_columns ) - + + if self._polymorphic_discriminator: + if adapter: + pd = adapter.columns[self._polymorphic_discriminator] + else: + pd = self._polymorphic_discriminator + context.primary_columns.append(pd) + def __str__(self): return str(self.mapper) diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py index b7759aaeb3..8e51105c9e 100644 --- a/test/orm/inheritance/basic.py +++ b/test/orm/inheritance/basic.py @@ -274,50 +274,6 @@ class GetTest(ORMTest): test_get_polymorphic = create_test(True, 'test_get_polymorphic') test_get_nonpolymorphic = create_test(False, 'test_get_nonpolymorphic') -class ConstructionTest(ORMTest): - def define_tables(self, metadata): - global content_type, content, product - content_type = Table('content_type', metadata, - Column('id', Integer, primary_key=True) - ) - content = Table('content', metadata, - Column('id', Integer, primary_key=True), - Column('content_type_id', Integer, ForeignKey('content_type.id')), - Column('type', String(30)) - ) - product = Table('product', metadata, - Column('id', Integer, ForeignKey('content.id'), primary_key=True) - ) - - def testbasic(self): - class ContentType(object): pass - class Content(object): pass - class Product(Content): pass - - content_types = mapper(ContentType, content_type) - try: - contents = mapper(Content, content, properties={ - 'content_type':relation(content_types) - }, polymorphic_identity='contents') - assert False - except sa_exc.ArgumentError, e: - assert str(e) == "Mapper 'Mapper|Content|content' specifies a polymorphic_identity of 'contents', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" - - def testbackref(self): - """tests adding a property to the superclass mapper""" - class ContentType(object): pass - class Content(object): pass - class Product(Content): pass - - contents = mapper(Content, content, polymorphic_on=content.c.type, polymorphic_identity='content') - products = mapper(Product, product, inherits=contents, polymorphic_identity='product') - content_types = mapper(ContentType, content_type, properties={ - 'content':relation(contents, backref='contenttype') - }) - p = Product() - p.contenttype = ContentType() - # TODO: assertion ?? - class EagerLazyTest(ORMTest): """tests eager load/lazy load of child items off inheritance mappers, tests that LazyLoader constructs the right query condition.""" diff --git a/test/orm/inheritance/concrete.py b/test/orm/inheritance/concrete.py index e6277f3e91..c523232c94 100644 --- a/test/orm/inheritance/concrete.py +++ b/test/orm/inheritance/concrete.py @@ -4,6 +4,40 @@ from sqlalchemy.orm import * from sqlalchemy.orm import exc as orm_exc from testlib import * from sqlalchemy.orm import attributes +from testlib.testing import eq_ + +class Employee(object): + def __init__(self, name): + self.name = name + def __repr__(self): + return self.__class__.__name__ + " " + self.name + +class Manager(Employee): + def __init__(self, name, manager_data): + self.name = name + self.manager_data = manager_data + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " " + self.manager_data + +class Engineer(Employee): + def __init__(self, name, engineer_info): + self.name = name + self.engineer_info = engineer_info + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " " + self.engineer_info + +class Hacker(Engineer): + def __init__(self, name, nickname, engineer_info): + self.name = name + self.nickname = nickname + self.engineer_info = engineer_info + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " '" + \ + self.nickname + "' " + self.engineer_info + +class Company(object): + pass + class ConcreteTest(ORMTest): def define_tables(self, metadata): @@ -42,26 +76,6 @@ class ConcreteTest(ORMTest): ) def test_basic(self): - class Employee(object): - def __init__(self, name): - self.name = name - def __repr__(self): - return self.__class__.__name__ + " " + self.name - - class Manager(Employee): - def __init__(self, name, manager_data): - self.name = name - self.manager_data = manager_data - def __repr__(self): - return self.__class__.__name__ + " " + self.name + " " + self.manager_data - - class Engineer(Employee): - def __init__(self, name, engineer_info): - self.name = name - self.engineer_info = engineer_info - def __repr__(self): - return self.__class__.__name__ + " " + self.name + " " + self.engineer_info - pjoin = polymorphic_union({ 'manager':managers_table, 'engineer':engineers_table @@ -77,45 +91,15 @@ class ConcreteTest(ORMTest): session.flush() session.clear() - print set([repr(x) for x in session.query(Employee).all()]) - assert set([repr(x) for x in session.query(Employee).all()]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"]) - assert set([repr(x) for x in session.query(Manager).all()]) == set(["Manager Tom knows how to manage things"]) - assert set([repr(x) for x in session.query(Engineer).all()]) == set(["Engineer Kurt knows how to hack"]) + assert set([repr(x) for x in session.query(Employee)]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"]) + assert set([repr(x) for x in session.query(Manager)]) == set(["Manager Tom knows how to manage things"]) + assert set([repr(x) for x in session.query(Engineer)]) == set(["Engineer Kurt knows how to hack"]) manager = session.query(Manager).one() session.expire(manager, ['manager_data']) self.assertEquals(manager.manager_data, "knows how to manage things") def test_multi_level_no_base(self): - class Employee(object): - def __init__(self, name): - self.name = name - def __repr__(self): - return self.__class__.__name__ + " " + self.name - - class Manager(Employee): - def __init__(self, name, manager_data): - self.name = name - self.manager_data = manager_data - def __repr__(self): - return self.__class__.__name__ + " " + self.name + " " + self.manager_data - - class Engineer(Employee): - def __init__(self, name, engineer_info): - self.name = name - self.engineer_info = engineer_info - def __repr__(self): - return self.__class__.__name__ + " " + self.name + " " + self.engineer_info - - class Hacker(Engineer): - def __init__(self, name, nickname, engineer_info): - self.name = name - self.nickname = nickname - self.engineer_info = engineer_info - def __repr__(self): - return self.__class__.__name__ + " " + self.name + " '" + \ - self.nickname + "' " + self.engineer_info - pjoin = polymorphic_union({ 'manager': managers_table, 'engineer': engineers_table, @@ -166,35 +150,6 @@ class ConcreteTest(ORMTest): assert set([repr(x) for x in session.query(Hacker).all()]) == set(["Hacker Kurt 'Badass' knows how to hack"]) def test_multi_level_with_base(self): - class Employee(object): - def __init__(self, name): - self.name = name - def __repr__(self): - return self.__class__.__name__ + " " + self.name - - class Manager(Employee): - def __init__(self, name, manager_data): - self.name = name - self.manager_data = manager_data - def __repr__(self): - return self.__class__.__name__ + " " + self.name + " " + self.manager_data - - class Engineer(Employee): - def __init__(self, name, engineer_info): - self.name = name - self.engineer_info = engineer_info - def __repr__(self): - return self.__class__.__name__ + " " + self.name + " " + self.engineer_info - - class Hacker(Engineer): - def __init__(self, name, nickname, engineer_info): - self.name = name - self.nickname = nickname - self.engineer_info = engineer_info - def __repr__(self): - return self.__class__.__name__ + " " + self.name + " '" + \ - self.nickname + "' " + self.engineer_info - pjoin = polymorphic_union({ 'employee':employees_table, 'manager': managers_table, @@ -227,12 +182,6 @@ class ConcreteTest(ORMTest): session.add_all((tom, jerry, hacker)) session.flush() - # ensure "readonly" on save logic didn't pollute the expired_attributes - # collection - assert 'nickname' not in attributes.instance_state(jerry).expired_attributes - assert 'name' not in attributes.instance_state(jerry).expired_attributes - assert 'name' not in attributes.instance_state(hacker).expired_attributes - assert 'nickname' not in attributes.instance_state(hacker).expired_attributes def go(): self.assertEquals(jerry.name, "Jerry") self.assertEquals(hacker.nickname, "Badass") @@ -245,35 +194,85 @@ class ConcreteTest(ORMTest): # in the statement which is only against Employee's "pjoin" assert len(testing.db.execute(session.query(Employee).with_labels().statement).fetchall()) == 3 - assert set([repr(x) for x in session.query(Employee).all()]) == set(["Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"]) - assert set([repr(x) for x in session.query(Manager).all()]) == set(["Manager Tom knows how to manage things"]) - assert set([repr(x) for x in session.query(Engineer).all()]) == set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"]) - assert set([repr(x) for x in session.query(Hacker).all()]) == set(["Hacker Kurt 'Badass' knows how to hack"]) + assert set([repr(x) for x in session.query(Employee)]) == set(["Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"]) + assert set([repr(x) for x in session.query(Manager)]) == set(["Manager Tom knows how to manage things"]) + assert set([repr(x) for x in session.query(Engineer)]) == set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"]) + assert set([repr(x) for x in session.query(Hacker)]) == set(["Hacker Kurt 'Badass' knows how to hack"]) - def test_relation(self): - class Employee(object): - def __init__(self, name): - self.name = name - def __repr__(self): - return self.__class__.__name__ + " " + self.name - - class Manager(Employee): - def __init__(self, name, manager_data): - self.name = name - self.manager_data = manager_data - def __repr__(self): - return self.__class__.__name__ + " " + self.name + " " + self.manager_data - - class Engineer(Employee): - def __init__(self, name, engineer_info): - self.name = name - self.engineer_info = engineer_info - def __repr__(self): - return self.__class__.__name__ + " " + self.name + " " + self.engineer_info - - class Company(object): - pass + + def test_without_default_polymorphic(self): + pjoin = polymorphic_union({ + 'employee':employees_table, + 'manager': managers_table, + 'engineer': engineers_table, + 'hacker': hackers_table + }, 'type', 'pjoin') + + pjoin2 = polymorphic_union({ + 'engineer': engineers_table, + 'hacker': hackers_table + }, 'type', 'pjoin2') + employee_mapper = mapper(Employee, employees_table, + polymorphic_identity='employee') + manager_mapper = mapper(Manager, managers_table, + inherits=employee_mapper, concrete=True, + polymorphic_identity='manager') + engineer_mapper = mapper(Engineer, engineers_table, + inherits=employee_mapper, concrete=True, + polymorphic_identity='engineer') + hacker_mapper = mapper(Hacker, hackers_table, + inherits=engineer_mapper, + concrete=True, polymorphic_identity='hacker') + + session = create_session() + jdoe = Employee('Jdoe') + tom = Manager('Tom', 'knows how to manage things') + jerry = Engineer('Jerry', 'knows how to program') + hacker = Hacker('Kurt', 'Badass', 'knows how to hack') + session.add_all((jdoe, tom, jerry, hacker)) + session.flush() + + eq_( + len(testing.db.execute(session.query(Employee).with_polymorphic('*', pjoin, pjoin.c.type).with_labels().statement).fetchall()), + 4 + ) + + eq_( + session.query(Employee).get(jdoe.employee_id), jdoe + ) + eq_( + session.query(Engineer).get(jerry.employee_id), jerry + ) + eq_( + set([repr(x) for x in session.query(Employee).with_polymorphic('*', pjoin, pjoin.c.type)]), + set(["Employee Jdoe", "Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"]) + ) + eq_( + set([repr(x) for x in session.query(Manager)]), + set(["Manager Tom knows how to manage things"]) + ) + eq_( + set([repr(x) for x in session.query(Engineer).with_polymorphic('*', pjoin2, pjoin2.c.type)]), + set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"]) + ) + eq_( + set([repr(x) for x in session.query(Hacker)]), + set(["Hacker Kurt 'Badass' knows how to hack"]) + ) + # test adaption of the column by wrapping the query in a subquery + eq_( + len(testing.db.execute( + session.query(Engineer).with_polymorphic('*', pjoin2, pjoin2.c.type).from_self().statement + ).fetchall()), + 2 + ) + eq_( + set([repr(x) for x in session.query(Engineer).with_polymorphic('*', pjoin2, pjoin2.c.type).from_self()]), + set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"]) + ) + + def test_relation(self): pjoin = polymorphic_union({ 'manager':managers_table, 'engineer':engineers_table @@ -342,11 +341,11 @@ class ColKeysTest(ORMTest): concrete=True, polymorphic_identity='refugee') sess = create_session() - assert sess.query(Refugee).get(1).name == "refugee1" - assert sess.query(Refugee).get(2).name == "refugee2" + eq_(sess.query(Refugee).get(1).name, "refugee1") + eq_(sess.query(Refugee).get(2).name, "refugee2") - assert sess.query(Office).get(1).name == "office1" - assert sess.query(Office).get(2).name == "office2" + eq_(sess.query(Office).get(1).name, "office1") + eq_(sess.query(Office).get(2).name, "office2") if __name__ == '__main__': testenv.main() -- 2.47.3