From: Mike Bayer Date: Sat, 4 Aug 2007 06:00:45 +0000 (+0000) Subject: - fix for [ticket:712], more unit tests X-Git-Tag: rel_0_4beta1~83 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=71df774b13c347760727d850acb1d4498b8cc5d7;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - fix for [ticket:712], more unit tests --- diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 4a19c3c911..fed7d2c928 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1395,7 +1395,7 @@ class Mapper(object): context.attributes[('polymorphic_fetch', mapper)] = (self, [t for t in mapper.tables if t not in self.tables]) row = self.translate_row(mapper, row) return mapper._instance(context, row, result=result, skip_polymorphic=True) - + # look in main identity map. if its there, we dont do anything to it, # including modifying any of its related items lists, as its already # been exposed to being modified by the application. diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 0afd24f70a..43b95a0fdc 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -558,7 +558,7 @@ class EagerLoader(AbstractRelationLoader): try: decorated_row = decorator(row) # check for identity key - identity_key = self.mapper.identity_key_from_row(decorated_row) + identity_key = self.select_mapper.identity_key_from_row(decorated_row) # and its good return decorator except KeyError, k: @@ -587,13 +587,13 @@ class EagerLoader(AbstractRelationLoader): # event handlers. # # FIXME: instead of... - sessionlib.attribute_manager.get_attribute(instance, self.key).set_raw_value(instance, self.mapper._instance(selectcontext, decorated_row, None)) + sessionlib.attribute_manager.get_attribute(instance, self.key).set_raw_value(instance, self.select_mapper._instance(selectcontext, decorated_row, None)) # bypass and set directly: #instance.__dict__[self.key] = ... else: # call _instance on the row, even though the object has been created, # so that we further descend into properties - self.mapper._instance(selectcontext, decorated_row, None) + self.select_mapper._instance(selectcontext, decorated_row, None) else: if isnew: if self._should_log_debug: diff --git a/test/orm/inheritance/concrete.py b/test/orm/inheritance/concrete.py index d95a96da5f..3443374d23 100644 --- a/test/orm/inheritance/concrete.py +++ b/test/orm/inheritance/concrete.py @@ -3,22 +3,29 @@ from sqlalchemy import * from sqlalchemy.orm import * from testlib import * -class ConcreteTest1(ORMTest): +class ConcreteTest(ORMTest): def define_tables(self, metadata): - global managers_table, engineers_table + global managers_table, engineers_table, companies + + companies = Table('companies', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50))) + managers_table = Table('managers', metadata, Column('employee_id', Integer, primary_key=True), Column('name', String(50)), Column('manager_data', String(50)), + Column('company_id', Integer, ForeignKey('companies.id')) ) engineers_table = Table('engineers', metadata, Column('employee_id', Integer, primary_key=True), Column('name', String(50)), Column('engineer_info', String(50)), + Column('company_id', Integer, ForeignKey('companies.id')) ) - def testbasic(self): + def test_basic(self): class Employee(object): def __init__(self, name): self.name = name @@ -59,10 +66,55 @@ class ConcreteTest1(ORMTest): assert set([repr(x) for x in session.query(Manager).select()]) == set(["Manager Tom knows how to manage things"]) assert set([repr(x) for x in session.query(Engineer).select()]) == set(["Engineer Kurt knows how to hack"]) - def testwithrelation(self): - pass + def test_relation(self): + class Employee(object): + def __init__(self, name): + self.name = name + def __repr__(self): + return self.__class__.__name__ + " " + self.name + + class Manager(Employee): + def __init__(self, name, manager_data): + self.name = name + self.manager_data = manager_data + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " " + self.manager_data + + class Engineer(Employee): + def __init__(self, name, engineer_info): + self.name = name + self.engineer_info = engineer_info + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " " + self.engineer_info + + class Company(object): + pass + + pjoin = polymorphic_union({ + 'manager':managers_table, + 'engineer':engineers_table + }, 'type', 'pjoin') + + mapper(Company, companies, properties={ + 'employees':relation(Employee, lazy=False) + }) + employee_mapper = mapper(Employee, pjoin, polymorphic_on=pjoin.c.type) + manager_mapper = mapper(Manager, managers_table, inherits=employee_mapper, concrete=True, polymorphic_identity='manager') + engineer_mapper = mapper(Engineer, engineers_table, inherits=employee_mapper, concrete=True, polymorphic_identity='engineer') + + session = create_session() + c = Company() + c.employees.append(Manager('Tom', 'knows how to manage things')) + c.employees.append(Engineer('Kurt', 'knows how to hack')) + session.save(c) + session.flush() + session.clear() + + def go(): + c2 = session.query(Company).get(c.id) + assert set([repr(x) for x in c2.employees]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"]) + self.assert_sql_count(testbase.db, go, 1) - # TODO: test a self-referential relationship on a concrete polymorphic mapping if __name__ == '__main__': diff --git a/test/orm/inheritance/polymorph.py b/test/orm/inheritance/polymorph.py index 3eb2e032f0..caee34b09d 100644 --- a/test/orm/inheritance/polymorph.py +++ b/test/orm/inheritance/polymorph.py @@ -263,14 +263,28 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co session.flush() session.clear() id = c.company_id - c = session.query(Company).get(id) - for e in c.employees: - print e, e._instance_key, e.company - if include_base: - assert sets.Set([(e.get_name(), getattr(e, 'status', None)) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('joesmith', None), ('wally', 'CGG'), ('jsmith', 'ABA')]) + def go(): + c = session.query(Company).get(id) + for e in c.employees: + print e, e._instance_key, e.company + assert e._instance_key[0] == Person + if include_base: + assert sets.Set([(e.get_name(), getattr(e, 'status', None)) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('joesmith', None), ('wally', 'CGG'), ('jsmith', 'ABA')]) + else: + assert sets.Set([(e.get_name(), e.status) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('wally', 'CGG'), ('jsmith', 'ABA')]) + print "\n" + + if not lazy_relation: + if polymorphic_fetch=='union': + self.assert_sql_count(testbase.db, go, 1) + else: + self.assert_sql_count(testbase.db, go, 5) + else: - assert sets.Set([(e.get_name(), e.status) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('wally', 'CGG'), ('jsmith', 'ABA')]) - print "\n" + if polymorphic_fetch=='union': + self.assert_sql_count(testbase.db, go, 2) + else: + self.assert_sql_count(testbase.db, go, 6) # test selecting from the query, using the base mapped table (people) as the selection criterion. # in the case of the polymorphic Person query, the "people" selectable should be adapted to be "person_join" diff --git a/test/orm/inheritance/polymorph2.py b/test/orm/inheritance/polymorph2.py index a2f9c4a5f0..8ed4b806a5 100644 --- a/test/orm/inheritance/polymorph2.py +++ b/test/orm/inheritance/polymorph2.py @@ -313,6 +313,7 @@ class RelationTest4(ORMTest): in the union. however, the primaryjoin condition is going to be against the base table, and its a many-to-one relationship (unlike the test in polymorph.py) so the column in the base table is explicit. Can the ClauseAdapter figure out how to alias the primaryjoin to the polymorphic union ?""" + # class definitions class Person(object): def __init__(self, **kwargs): @@ -397,9 +398,12 @@ class RelationTest4(ORMTest): session.clear() print "-----------------------------------------------------------------" # and now for the lightning round, eager ! - car1 = session.query(Car).options(eagerload('employee')).get(car1.car_id) - assert str(car1.employee) == "Engineer E4, status X" + def go(): + testcar = session.query(Car).options(eagerload('employee')).get(car1.car_id) + assert str(testcar.employee) == "Engineer E4, status X" + self.assert_sql_count(testbase.db, go, 1) + session.clear() s = session.query(Car) c = s.join("employee").filter(Person.name=="E4")[0]