]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fixed bug in eager loading on a many-to-one [ticket:96], added the ticket tests as...
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 6 Mar 2006 19:06:06 +0000 (19:06 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 6 Mar 2006 19:06:06 +0000 (19:06 +0000)
got eagerload1 to be a unit test also.

lib/sqlalchemy/mapping/properties.py
lib/sqlalchemy/util.py
test/alltests.py
test/eagertest1.py
test/eagertest2.py [new file with mode: 0644]

index 0f83568eddb892e6252ca1c2df649cf19f4b4830..617f3bbaf765acba56840e7efa0cca28bb35e8ff 100644 (file)
@@ -772,6 +772,11 @@ class EagerLoader(PropertyLoader):
         if not self.uselist:
             if isnew:
                 h.setattr_clean(self._instance(row, imap))
+            else:
+                # call _instance on the row, even though the object has been created,
+                # so that we further descend into properties
+                self._instance(row, imap)
+                
             return
         elif isnew:
             result_list = h
index 02bd5d587c57f039ba13b8b30aa992bbf8514838..7115dbcec47d2a0960e26f0c08dcaab990224705 100644 (file)
@@ -229,6 +229,8 @@ class DictDecorator(dict):
             return dict.__getitem__(self, key)
         except KeyError:
             return self.decorate[key]
+    def __repr__(self):
+        return dict.__repr__(self) + repr(self.decorate)
 class HashSet(object):
     """implements a Set."""
     def __init__(self, iter=None, ordered=False):
index d60ba6272bf6145ce6ec586b26a79a4a40a3c80c..3199b89f9193ff9a4cfe2b1495af09beacdade0b 100644 (file)
@@ -32,6 +32,8 @@ def suite():
         
         # ORM selecting
         'mapper',
+        'eagertest1',
+        'eagertest2',
         
         # ORM persistence
         'objectstore',
index ab4a69c1b64b68456ac81ed78c0e8fb24c420dad..5897e401621ec0a43dcaff4ba1fcf1f55d3c5532 100644 (file)
@@ -1,60 +1,74 @@
+from testbase import PersistTest, AssertMixin
+import testbase
+import unittest, sys, os
 from sqlalchemy import *
+import datetime
 
-class Part(object):pass
-class Design(object):pass
-class DesignType(object):pass
-class InheritedPart(object):pass
+class EagerTest(AssertMixin):
+    def setUpAll(self):
+        global designType, design, part, inheritedPart
+        
+        designType = Table('design_types', testbase.db, 
+               Column('design_type_id', Integer, primary_key=True),
+               )
 
-engine = create_engine('sqlite://', echo=True)
+        design =Table('design', testbase.db, 
+               Column('design_id', Integer, primary_key=True),
+               Column('design_type_id', Integer, ForeignKey('design_types.design_type_id')))
 
-designType = Table('design_types', engine, 
-       Column('design_type_id', Integer, primary_key=True),
-       )
+        part = Table('parts', testbase.db, 
+               Column('part_id', Integer, primary_key=True),
+               Column('design_id', Integer, ForeignKey('design.design_id')),
+               Column('design_type_id', Integer, ForeignKey('design_types.design_type_id')))
 
-design =Table('design', engine, 
-       Column('design_id', Integer, primary_key=True),
-       Column('design_type_id', Integer, ForeignKey('design_types.design_type_id')))
+        inheritedPart = Table('inherited_part', testbase.db,
+               Column('ip_id', Integer, primary_key=True),
+               Column('part_id', Integer, ForeignKey('parts.part_id')),
+               Column('design_id', Integer, ForeignKey('design.design_id')),
+               )
 
-part = Table('parts', engine, 
-       Column('part_id', Integer, primary_key=True),
-       Column('design_id', Integer, ForeignKey('design.design_id')),
-       Column('design_type_id', Integer, ForeignKey('design_types.design_type_id')))
+        designType.create()
+        design.create()
+        part.create()
+        inheritedPart.create()
+    def tearDownAll(self):
+        inheritedPart.drop()
+        part.drop()
+        design.drop()
+        designType.drop()
+    
+    def testone(self):
+        class Part(object):pass
+        class Design(object):pass
+        class DesignType(object):pass
+        class InheritedPart(object):pass
        
-inheritedPart = Table('inherited_part', engine,
-       Column('ip_id', Integer, primary_key=True),
-       Column('part_id', Integer, ForeignKey('parts.part_id')),
-       Column('design_id', Integer, ForeignKey('design.design_id')),
-       )
-
-designType.create()
-design.create()
-part.create()
-inheritedPart.create()
-       
-assign_mapper(Part, part)
-
-assign_mapper(InheritedPart, inheritedPart, properties=dict(
-       part=relation(Part, lazy=False)
-))
-
-assign_mapper(Design, design, properties=dict(
-       parts=relation(Part, private=True, backref="design"),
-       inheritedParts=relation(InheritedPart, private=True, backref="design"),
-))
-
-assign_mapper(DesignType, designType, properties=dict(
-#      designs=relation(Design, private=True, backref="type"),
-))
-
-Design.mapper.add_property("type", relation(DesignType, lazy=False, backref="designs"))
-Part.mapper.add_property("design", relation(Design, lazy=False, backref="parts"))
-#Part.mapper.add_property("designType", relation(DesignType))
-
-d = Design()
-objectstore.commit()
-objectstore.clear()
-print "lets go !\n\n\n"
-x = Design.get(1)
-x.inheritedParts
+        assign_mapper(Part, part)
+
+        assign_mapper(InheritedPart, inheritedPart, properties=dict(
+               part=relation(Part, lazy=False)
+        ))
+
+        assign_mapper(Design, design, properties=dict(
+               parts=relation(Part, private=True, backref="design"),
+               inheritedParts=relation(InheritedPart, private=True, backref="design"),
+        ))
+
+        assign_mapper(DesignType, designType, properties=dict(
+        #      designs=relation(Design, private=True, backref="type"),
+        ))
+
+        Design.mapper.add_property("type", relation(DesignType, lazy=False, backref="designs"))
+        Part.mapper.add_property("design", relation(Design, lazy=False, backref="parts"))
+        #Part.mapper.add_property("designType", relation(DesignType))
+
+        d = Design()
+        objectstore.commit()
+        objectstore.clear()
+        x = Design.get(1)
+        x.inheritedParts
+
+if __name__ == "__main__":    
+    testbase.main()
 
 
diff --git a/test/eagertest2.py b/test/eagertest2.py
new file mode 100644 (file)
index 0000000..430e12b
--- /dev/null
@@ -0,0 +1,254 @@
+from testbase import PersistTest, AssertMixin
+import testbase
+import unittest, sys, os
+from sqlalchemy import *
+import datetime
+
+db = testbase.db
+
+class EagerTest(AssertMixin):
+    def setUpAll(self):
+        objectstore.clear()
+        clear_mappers()
+        testbase.db.tables.clear()
+        
+        global companies_table, addresses_table, invoice_table, phones_table, items_table
+
+        companies_table = Table('companies', db,
+            Column('company_id', Integer, Sequence('company_id_seq', optional=True), primary_key = True),
+            Column('company_name', String(40)),
+
+        )
+        
+        addresses_table = Table('addresses', db,
+                                Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True),
+                                Column('company_id', Integer, ForeignKey("companies.company_id")),
+                                Column('address', String(40)),
+                                )
+
+        phones_table = Table('phone_numbers', db,
+                                Column('phone_id', Integer, Sequence('phone_id_seq', optional=True), primary_key = True),
+                                Column('address_id', Integer, ForeignKey('addresses.address_id')),
+                                Column('type', String(20)),
+                                Column('number', String(10)),
+                                )
+
+        invoice_table = Table('invoices', db,
+                              Column('invoice_id', Integer, Sequence('invoice_id_seq', optional=True), primary_key = True),
+                              Column('company_id', Integer, ForeignKey("companies.company_id")),
+                              Column('date', DateTime),   
+                              )
+
+        items_table = Table('items', db,
+                            Column('item_id', Integer, Sequence('item_id_seq', optional=True), primary_key = True),
+                            Column('invoice_id', Integer, ForeignKey('invoices.invoice_id')),
+                            Column('code', String(20)),
+                            Column('qty', Integer),
+                            )
+
+        companies_table.create()
+        addresses_table.create()
+        phones_table.create()
+        invoice_table.create()
+        items_table.create()
+        
+    def tearDownAll(self):
+        items_table.drop()
+        invoice_table.drop()
+        phones_table.drop()
+        addresses_table.drop()
+        companies_table.drop()
+
+    def tearDown(self):
+        objectstore.clear()
+        clear_mappers()
+        items_table.delete().execute()
+        invoice_table.delete().execute()
+        phones_table.delete().execute()
+        addresses_table.delete().execute()
+        companies_table.delete().execute()
+
+    def testone(self):
+        """tests eager load of a many-to-one attached to a one-to-many.  this testcase illustrated 
+        the bug, which is that when the single Company is loaded, no further processing of the rows
+        occurred in order to load the Company's second Address object."""
+        class Company(object):
+            def __init__(self):
+                self.company_id = None
+            def __repr__(self):
+                return "Company:" + repr(getattr(self, 'company_id', None)) + " " + repr(getattr(self, 'company_name', None)) + " " + str([repr(addr) for addr in self.addresses])
+
+        class Address(object):
+            def __repr__(self):
+                return "Address: " + repr(getattr(self, 'address_id', None)) + " " + repr(getattr(self, 'company_id', None)) + " " + repr(self.address)
+
+        class Invoice(object):
+            def __init__(self):
+                self.invoice_id = None
+            def __repr__(self):
+                return "Invoice:" + repr(getattr(self, 'invoice_id', None)) + " " + repr(getattr(self, 'date', None))  + " " + repr(self.company)
+
+        Address.mapper = mapper(Address, addresses_table, properties={
+            })
+        Company.mapper = mapper(Company, companies_table, properties={
+            'addresses' : relation(Address.mapper, lazy=False),
+            })
+        Invoice.mapper = mapper(Invoice, invoice_table, properties={
+            'company': relation(Company.mapper, lazy=False, )
+            })
+
+        c1 = Company()
+        c1.company_name = 'company 1'
+        a1 = Address()
+        a1.address = 'a1 address'
+        c1.addresses.append(a1)
+        a2 = Address()
+        a2.address = 'a2 address'
+        c1.addresses.append(a2)
+        i1 = Invoice()
+        i1.date = datetime.datetime.now()
+        i1.company = c1
+
+        
+        objectstore.commit()
+
+        company_id = c1.company_id
+        invoice_id = i1.invoice_id
+
+        objectstore.clear()
+
+        c = Company.mapper.get(company_id)
+
+        objectstore.clear()
+
+        i = Invoice.mapper.get(invoice_id)
+
+        self.echo(repr(c))
+        self.echo(repr(i.company))
+        self.assert_(repr(c) == repr(i.company))
+
+    def testtwo(self):
+        """this is the original testcase that includes various complicating factors"""
+        class Company(object):
+            def __init__(self):
+                self.company_id = None
+            def __repr__(self):
+                return "Company:" + repr(getattr(self, 'company_id', None)) + " " + repr(getattr(self, 'company_name', None)) + " " + str([repr(addr) for addr in self.addresses])
+
+        class Address(object):
+            def __repr__(self):
+                return "Address: " + repr(getattr(self, 'address_id', None)) + " " + repr(getattr(self, 'company_id', None)) + " " + repr(self.address) + str([repr(ph) for ph in self.phones])
+
+        class Phone(object):
+            def __repr__(self):
+                return "Phone: " + repr(getattr(self, 'phone_id', None)) + " " + repr(getattr(self, 'address_id', None)) + " " + repr(self.type) + " " + repr(self.number)
+
+        class Invoice(object):
+            def __init__(self):
+                self.invoice_id = None
+            def __repr__(self):
+                return "Invoice:" + repr(getattr(self, 'invoice_id', None)) + " " + repr(getattr(self, 'date', None))  + " " + repr(self.company) + " " + str([repr(item) for item in self.items])
+
+        class Item(object):
+            def __repr__(self):
+                return "Item: " + repr(getattr(self, 'item_id', None)) + " " + repr(getattr(self, 'invoice_id', None)) + " " + repr(self.code) + " " + repr(self.qty)
+
+        Phone.mapper = mapper(Phone, phones_table, is_primary=True)
+
+        Address.mapper = mapper(Address, addresses_table, properties={
+            'phones': relation(Phone.mapper, lazy=False, backref='address')
+            })
+
+        Company.mapper = mapper(Company, companies_table, properties={
+            'addresses' : relation(Address.mapper, lazy=False, backref='company'),
+            })
+
+        Item.mapper = mapper(Item, items_table, is_primary=True)
+
+        Invoice.mapper = mapper(Invoice, invoice_table, properties={
+            'items': relation(Item.mapper, lazy=False, backref='invoice'),
+            'company': relation(Company.mapper, lazy=False, backref='invoices')
+            })
+
+        objectstore.clear()
+        c1 = Company()
+        c1.company_name = 'company 1'
+
+        a1 = Address()
+        a1.address = 'a1 address'
+
+        p1 = Phone()
+        p1.type = 'home'
+        p1.number = '1111'
+
+        a1.phones.append(p1)
+
+        p2 = Phone()
+        p2.type = 'work'
+        p2.number = '22222'
+        a1.phones.append(p2)
+
+        c1.addresses.append(a1)
+
+        a2 = Address()
+        a2.address = 'a2 address'
+
+        p3 = Phone()
+        p3.type = 'home'
+        p3.number = '3333'
+        a2.phones.append(p3)
+
+        p4 = Phone()
+        p4.type = 'work'
+        p4.number = '44444'
+        a2.phones.append(p4)
+
+        c1.addresses.append(a2)
+
+        objectstore.commit()
+
+        company_id = c1.company_id
+        
+        objectstore.clear()
+
+        a = Company.mapper.get(company_id)
+        self.echo(repr(a))
+
+        # set up an invoice
+        i1 = Invoice()
+        i1.date = datetime.datetime.now()
+        i1.company = c1
+
+        item1 = Item()
+        item1.code = 'aaaa'
+        item1.qty = 1
+        item1.invoice = i1
+
+        item2 = Item()
+        item2.code = 'bbbb'
+        item2.qty = 2
+        item2.invoice = i1
+
+        item3 = Item()
+        item3.code = 'cccc'
+        item3.qty = 3
+        item3.invoice = i1
+
+        objectstore.commit()
+
+        invoice_id = i1.invoice_id
+
+        objectstore.clear()
+
+        c = Company.mapper.get(company_id)
+        self.echo(repr(c))
+
+        objectstore.clear()
+
+        i = Invoice.mapper.get(invoice_id)
+        self.echo(repr(i))
+
+        self.assert_(repr(i.company) == repr(c))
+        
+if __name__ == "__main__":    
+    testbase.main()