From 32aa76601daffba2c43b784f587c8d614e495324 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 6 Mar 2006 19:06:06 +0000 Subject: [PATCH] fixed bug in eager loading on a many-to-one [ticket:96], added the ticket tests as a unit test eagerload2. got eagerload1 to be a unit test also. --- lib/sqlalchemy/mapping/properties.py | 5 + lib/sqlalchemy/util.py | 2 + test/alltests.py | 2 + test/eagertest1.py | 116 ++++++------ test/eagertest2.py | 254 +++++++++++++++++++++++++++ 5 files changed, 328 insertions(+), 51 deletions(-) create mode 100644 test/eagertest2.py diff --git a/lib/sqlalchemy/mapping/properties.py b/lib/sqlalchemy/mapping/properties.py index 0f83568edd..617f3bbaf7 100644 --- a/lib/sqlalchemy/mapping/properties.py +++ b/lib/sqlalchemy/mapping/properties.py @@ -772,6 +772,11 @@ class EagerLoader(PropertyLoader): if not self.uselist: if isnew: h.setattr_clean(self._instance(row, imap)) + else: + # call _instance on the row, even though the object has been created, + # so that we further descend into properties + self._instance(row, imap) + return elif isnew: result_list = h diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 02bd5d587c..7115dbcec4 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -229,6 +229,8 @@ class DictDecorator(dict): return dict.__getitem__(self, key) except KeyError: return self.decorate[key] + def __repr__(self): + return dict.__repr__(self) + repr(self.decorate) class HashSet(object): """implements a Set.""" def __init__(self, iter=None, ordered=False): diff --git a/test/alltests.py b/test/alltests.py index d60ba6272b..3199b89f91 100644 --- a/test/alltests.py +++ b/test/alltests.py @@ -32,6 +32,8 @@ def suite(): # ORM selecting 'mapper', + 'eagertest1', + 'eagertest2', # ORM persistence 'objectstore', diff --git a/test/eagertest1.py b/test/eagertest1.py index ab4a69c1b6..5897e40162 100644 --- a/test/eagertest1.py +++ b/test/eagertest1.py @@ -1,60 +1,74 @@ +from testbase import PersistTest, AssertMixin +import testbase +import unittest, sys, os from sqlalchemy import * +import datetime -class Part(object):pass -class Design(object):pass -class DesignType(object):pass -class InheritedPart(object):pass +class EagerTest(AssertMixin): + def setUpAll(self): + global designType, design, part, inheritedPart + + designType = Table('design_types', testbase.db, + Column('design_type_id', Integer, primary_key=True), + ) -engine = create_engine('sqlite://', echo=True) + design =Table('design', testbase.db, + Column('design_id', Integer, primary_key=True), + Column('design_type_id', Integer, ForeignKey('design_types.design_type_id'))) -designType = Table('design_types', engine, - Column('design_type_id', Integer, primary_key=True), - ) + part = Table('parts', testbase.db, + Column('part_id', Integer, primary_key=True), + Column('design_id', Integer, ForeignKey('design.design_id')), + Column('design_type_id', Integer, ForeignKey('design_types.design_type_id'))) -design =Table('design', engine, - Column('design_id', Integer, primary_key=True), - Column('design_type_id', Integer, ForeignKey('design_types.design_type_id'))) + inheritedPart = Table('inherited_part', testbase.db, + Column('ip_id', Integer, primary_key=True), + Column('part_id', Integer, ForeignKey('parts.part_id')), + Column('design_id', Integer, ForeignKey('design.design_id')), + ) -part = Table('parts', engine, - Column('part_id', Integer, primary_key=True), - Column('design_id', Integer, ForeignKey('design.design_id')), - Column('design_type_id', Integer, ForeignKey('design_types.design_type_id'))) + designType.create() + design.create() + part.create() + inheritedPart.create() + def tearDownAll(self): + inheritedPart.drop() + part.drop() + design.drop() + designType.drop() + + def testone(self): + class Part(object):pass + class Design(object):pass + class DesignType(object):pass + class InheritedPart(object):pass -inheritedPart = Table('inherited_part', engine, - Column('ip_id', Integer, primary_key=True), - Column('part_id', Integer, ForeignKey('parts.part_id')), - Column('design_id', Integer, ForeignKey('design.design_id')), - ) - -designType.create() -design.create() -part.create() -inheritedPart.create() - -assign_mapper(Part, part) - -assign_mapper(InheritedPart, inheritedPart, properties=dict( - part=relation(Part, lazy=False) -)) - -assign_mapper(Design, design, properties=dict( - parts=relation(Part, private=True, backref="design"), - inheritedParts=relation(InheritedPart, private=True, backref="design"), -)) - -assign_mapper(DesignType, designType, properties=dict( -# designs=relation(Design, private=True, backref="type"), -)) - -Design.mapper.add_property("type", relation(DesignType, lazy=False, backref="designs")) -Part.mapper.add_property("design", relation(Design, lazy=False, backref="parts")) -#Part.mapper.add_property("designType", relation(DesignType)) - -d = Design() -objectstore.commit() -objectstore.clear() -print "lets go !\n\n\n" -x = Design.get(1) -x.inheritedParts + assign_mapper(Part, part) + + assign_mapper(InheritedPart, inheritedPart, properties=dict( + part=relation(Part, lazy=False) + )) + + assign_mapper(Design, design, properties=dict( + parts=relation(Part, private=True, backref="design"), + inheritedParts=relation(InheritedPart, private=True, backref="design"), + )) + + assign_mapper(DesignType, designType, properties=dict( + # designs=relation(Design, private=True, backref="type"), + )) + + Design.mapper.add_property("type", relation(DesignType, lazy=False, backref="designs")) + Part.mapper.add_property("design", relation(Design, lazy=False, backref="parts")) + #Part.mapper.add_property("designType", relation(DesignType)) + + d = Design() + objectstore.commit() + objectstore.clear() + x = Design.get(1) + x.inheritedParts + +if __name__ == "__main__": + testbase.main() diff --git a/test/eagertest2.py b/test/eagertest2.py new file mode 100644 index 0000000000..430e12ba72 --- /dev/null +++ b/test/eagertest2.py @@ -0,0 +1,254 @@ +from testbase import PersistTest, AssertMixin +import testbase +import unittest, sys, os +from sqlalchemy import * +import datetime + +db = testbase.db + +class EagerTest(AssertMixin): + def setUpAll(self): + objectstore.clear() + clear_mappers() + testbase.db.tables.clear() + + global companies_table, addresses_table, invoice_table, phones_table, items_table + + companies_table = Table('companies', db, + Column('company_id', Integer, Sequence('company_id_seq', optional=True), primary_key = True), + Column('company_name', String(40)), + + ) + + addresses_table = Table('addresses', db, + Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True), + Column('company_id', Integer, ForeignKey("companies.company_id")), + Column('address', String(40)), + ) + + phones_table = Table('phone_numbers', db, + Column('phone_id', Integer, Sequence('phone_id_seq', optional=True), primary_key = True), + Column('address_id', Integer, ForeignKey('addresses.address_id')), + Column('type', String(20)), + Column('number', String(10)), + ) + + invoice_table = Table('invoices', db, + Column('invoice_id', Integer, Sequence('invoice_id_seq', optional=True), primary_key = True), + Column('company_id', Integer, ForeignKey("companies.company_id")), + Column('date', DateTime), + ) + + items_table = Table('items', db, + Column('item_id', Integer, Sequence('item_id_seq', optional=True), primary_key = True), + Column('invoice_id', Integer, ForeignKey('invoices.invoice_id')), + Column('code', String(20)), + Column('qty', Integer), + ) + + companies_table.create() + addresses_table.create() + phones_table.create() + invoice_table.create() + items_table.create() + + def tearDownAll(self): + items_table.drop() + invoice_table.drop() + phones_table.drop() + addresses_table.drop() + companies_table.drop() + + def tearDown(self): + objectstore.clear() + clear_mappers() + items_table.delete().execute() + invoice_table.delete().execute() + phones_table.delete().execute() + addresses_table.delete().execute() + companies_table.delete().execute() + + def testone(self): + """tests eager load of a many-to-one attached to a one-to-many. this testcase illustrated + the bug, which is that when the single Company is loaded, no further processing of the rows + occurred in order to load the Company's second Address object.""" + class Company(object): + def __init__(self): + self.company_id = None + def __repr__(self): + return "Company:" + repr(getattr(self, 'company_id', None)) + " " + repr(getattr(self, 'company_name', None)) + " " + str([repr(addr) for addr in self.addresses]) + + class Address(object): + def __repr__(self): + return "Address: " + repr(getattr(self, 'address_id', None)) + " " + repr(getattr(self, 'company_id', None)) + " " + repr(self.address) + + class Invoice(object): + def __init__(self): + self.invoice_id = None + def __repr__(self): + return "Invoice:" + repr(getattr(self, 'invoice_id', None)) + " " + repr(getattr(self, 'date', None)) + " " + repr(self.company) + + Address.mapper = mapper(Address, addresses_table, properties={ + }) + Company.mapper = mapper(Company, companies_table, properties={ + 'addresses' : relation(Address.mapper, lazy=False), + }) + Invoice.mapper = mapper(Invoice, invoice_table, properties={ + 'company': relation(Company.mapper, lazy=False, ) + }) + + c1 = Company() + c1.company_name = 'company 1' + a1 = Address() + a1.address = 'a1 address' + c1.addresses.append(a1) + a2 = Address() + a2.address = 'a2 address' + c1.addresses.append(a2) + i1 = Invoice() + i1.date = datetime.datetime.now() + i1.company = c1 + + + objectstore.commit() + + company_id = c1.company_id + invoice_id = i1.invoice_id + + objectstore.clear() + + c = Company.mapper.get(company_id) + + objectstore.clear() + + i = Invoice.mapper.get(invoice_id) + + self.echo(repr(c)) + self.echo(repr(i.company)) + self.assert_(repr(c) == repr(i.company)) + + def testtwo(self): + """this is the original testcase that includes various complicating factors""" + class Company(object): + def __init__(self): + self.company_id = None + def __repr__(self): + return "Company:" + repr(getattr(self, 'company_id', None)) + " " + repr(getattr(self, 'company_name', None)) + " " + str([repr(addr) for addr in self.addresses]) + + class Address(object): + def __repr__(self): + return "Address: " + repr(getattr(self, 'address_id', None)) + " " + repr(getattr(self, 'company_id', None)) + " " + repr(self.address) + str([repr(ph) for ph in self.phones]) + + class Phone(object): + def __repr__(self): + return "Phone: " + repr(getattr(self, 'phone_id', None)) + " " + repr(getattr(self, 'address_id', None)) + " " + repr(self.type) + " " + repr(self.number) + + class Invoice(object): + def __init__(self): + self.invoice_id = None + def __repr__(self): + return "Invoice:" + repr(getattr(self, 'invoice_id', None)) + " " + repr(getattr(self, 'date', None)) + " " + repr(self.company) + " " + str([repr(item) for item in self.items]) + + class Item(object): + def __repr__(self): + return "Item: " + repr(getattr(self, 'item_id', None)) + " " + repr(getattr(self, 'invoice_id', None)) + " " + repr(self.code) + " " + repr(self.qty) + + Phone.mapper = mapper(Phone, phones_table, is_primary=True) + + Address.mapper = mapper(Address, addresses_table, properties={ + 'phones': relation(Phone.mapper, lazy=False, backref='address') + }) + + Company.mapper = mapper(Company, companies_table, properties={ + 'addresses' : relation(Address.mapper, lazy=False, backref='company'), + }) + + Item.mapper = mapper(Item, items_table, is_primary=True) + + Invoice.mapper = mapper(Invoice, invoice_table, properties={ + 'items': relation(Item.mapper, lazy=False, backref='invoice'), + 'company': relation(Company.mapper, lazy=False, backref='invoices') + }) + + objectstore.clear() + c1 = Company() + c1.company_name = 'company 1' + + a1 = Address() + a1.address = 'a1 address' + + p1 = Phone() + p1.type = 'home' + p1.number = '1111' + + a1.phones.append(p1) + + p2 = Phone() + p2.type = 'work' + p2.number = '22222' + a1.phones.append(p2) + + c1.addresses.append(a1) + + a2 = Address() + a2.address = 'a2 address' + + p3 = Phone() + p3.type = 'home' + p3.number = '3333' + a2.phones.append(p3) + + p4 = Phone() + p4.type = 'work' + p4.number = '44444' + a2.phones.append(p4) + + c1.addresses.append(a2) + + objectstore.commit() + + company_id = c1.company_id + + objectstore.clear() + + a = Company.mapper.get(company_id) + self.echo(repr(a)) + + # set up an invoice + i1 = Invoice() + i1.date = datetime.datetime.now() + i1.company = c1 + + item1 = Item() + item1.code = 'aaaa' + item1.qty = 1 + item1.invoice = i1 + + item2 = Item() + item2.code = 'bbbb' + item2.qty = 2 + item2.invoice = i1 + + item3 = Item() + item3.code = 'cccc' + item3.qty = 3 + item3.invoice = i1 + + objectstore.commit() + + invoice_id = i1.invoice_id + + objectstore.clear() + + c = Company.mapper.get(company_id) + self.echo(repr(c)) + + objectstore.clear() + + i = Invoice.mapper.get(invoice_id) + self.echo(repr(i)) + + self.assert_(repr(i.company) == repr(c)) + +if __name__ == "__main__": + testbase.main() -- 2.47.2