]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fix for [ticket:712], more unit tests
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 4 Aug 2007 06:00:45 +0000 (06:00 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 4 Aug 2007 06:00:45 +0000 (06:00 +0000)
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/strategies.py
test/orm/inheritance/concrete.py
test/orm/inheritance/polymorph.py
test/orm/inheritance/polymorph2.py

index 4a19c3c911edc1c1abd15ba1a3f6daeb92164f84..fed7d2c928ba31e26bf58af8396b93562cec3ad4 100644 (file)
@@ -1395,7 +1395,7 @@ class Mapper(object):
                         context.attributes[('polymorphic_fetch', mapper)] = (self, [t for t in mapper.tables if t not in self.tables])
                     row = self.translate_row(mapper, row)
                     return mapper._instance(context, row, result=result, skip_polymorphic=True)
-                    
+        
         # look in main identity map.  if its there, we dont do anything to it,
         # including modifying any of its related items lists, as its already
         # been exposed to being modified by the application.
index 0afd24f70aed43a912f7c85616ac71f21a40d70e..43b95a0fdc934c04d3bbe356cd9b452abbccc91a 100644 (file)
@@ -558,7 +558,7 @@ class EagerLoader(AbstractRelationLoader):
         try:
             decorated_row = decorator(row)
             # check for identity key
-            identity_key = self.mapper.identity_key_from_row(decorated_row)
+            identity_key = self.select_mapper.identity_key_from_row(decorated_row)
             # and its good
             return decorator
         except KeyError, k:
@@ -587,13 +587,13 @@ class EagerLoader(AbstractRelationLoader):
                         # event handlers.
                         #
                         # FIXME: instead of...
-                        sessionlib.attribute_manager.get_attribute(instance, self.key).set_raw_value(instance, self.mapper._instance(selectcontext, decorated_row, None))
+                        sessionlib.attribute_manager.get_attribute(instance, self.key).set_raw_value(instance, self.select_mapper._instance(selectcontext, decorated_row, None))
                         # bypass and set directly:
                         #instance.__dict__[self.key] = ...
                     else:
                         # call _instance on the row, even though the object has been created,
                         # so that we further descend into properties
-                        self.mapper._instance(selectcontext, decorated_row, None)
+                        self.select_mapper._instance(selectcontext, decorated_row, None)
                 else:
                     if isnew:
                         if self._should_log_debug:
index d95a96da5f3cf2e401597ba8a4c1c48906b001bb..3443374d234551bc8aa27909719bc4b82230eea4 100644 (file)
@@ -3,22 +3,29 @@ from sqlalchemy import *
 from sqlalchemy.orm import *
 from testlib import *
 
-class ConcreteTest1(ORMTest):
+class ConcreteTest(ORMTest):
     def define_tables(self, metadata):
-        global managers_table, engineers_table
+        global managers_table, engineers_table, companies
+
+        companies = Table('companies', metadata, 
+           Column('id', Integer, primary_key=True),
+           Column('name', String(50)))
+        
         managers_table = Table('managers', metadata, 
             Column('employee_id', Integer, primary_key=True),
             Column('name', String(50)),
             Column('manager_data', String(50)),
+            Column('company_id', Integer, ForeignKey('companies.id'))
         )
 
         engineers_table = Table('engineers', metadata, 
             Column('employee_id', Integer, primary_key=True),
             Column('name', String(50)),
             Column('engineer_info', String(50)),
+            Column('company_id', Integer, ForeignKey('companies.id'))
         )
 
-    def testbasic(self):
+    def test_basic(self):
         class Employee(object):
             def __init__(self, name):
                 self.name = name
@@ -59,10 +66,55 @@ class ConcreteTest1(ORMTest):
         assert set([repr(x) for x in session.query(Manager).select()]) == set(["Manager Tom knows how to manage things"])
         assert set([repr(x) for x in session.query(Engineer).select()]) == set(["Engineer Kurt knows how to hack"])
 
-    def testwithrelation(self):
-        pass
+    def test_relation(self):
+        class Employee(object):
+            def __init__(self, name):
+                self.name = name
+            def __repr__(self):
+                return self.__class__.__name__ + " " + self.name
+
+        class Manager(Employee):
+            def __init__(self, name, manager_data):
+                self.name = name
+                self.manager_data = manager_data
+            def __repr__(self):
+                return self.__class__.__name__ + " " + self.name + " " +  self.manager_data
+
+        class Engineer(Employee):
+            def __init__(self, name, engineer_info):
+                self.name = name
+                self.engineer_info = engineer_info
+            def __repr__(self):
+                return self.__class__.__name__ + " " + self.name + " " +  self.engineer_info
+
+        class Company(object):
+            pass
+
+        pjoin = polymorphic_union({
+            'manager':managers_table,
+            'engineer':engineers_table
+        }, 'type', 'pjoin')
+
+        mapper(Company, companies, properties={
+            'employees':relation(Employee, lazy=False)
+        })
+        employee_mapper = mapper(Employee, pjoin, polymorphic_on=pjoin.c.type)
+        manager_mapper = mapper(Manager, managers_table, inherits=employee_mapper, concrete=True, polymorphic_identity='manager')
+        engineer_mapper = mapper(Engineer, engineers_table, inherits=employee_mapper, concrete=True, polymorphic_identity='engineer')
+
+        session = create_session()
+        c = Company()
+        c.employees.append(Manager('Tom', 'knows how to manage things'))
+        c.employees.append(Engineer('Kurt', 'knows how to hack'))
+        session.save(c)
+        session.flush()
+        session.clear()
+
+        def go():
+            c2 = session.query(Company).get(c.id)
+            assert set([repr(x) for x in c2.employees]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"])
+        self.assert_sql_count(testbase.db, go, 1)
         
-        # TODO: test a self-referential relationship on a concrete polymorphic mapping
 
 
 if __name__ == '__main__':
index 3eb2e032f0442ca036a222a48b092aace867c623..caee34b09dcc78a868617817563f10719e90ef3f 100644 (file)
@@ -263,14 +263,28 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
         session.flush()
         session.clear()
         id = c.company_id
-        c = session.query(Company).get(id)
-        for e in c.employees:
-            print e, e._instance_key, e.company
-        if include_base:
-            assert sets.Set([(e.get_name(), getattr(e, 'status', None)) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('joesmith', None), ('wally', 'CGG'), ('jsmith', 'ABA')])
+        def go():
+            c = session.query(Company).get(id)
+            for e in c.employees:
+                print e, e._instance_key, e.company
+                assert e._instance_key[0] == Person
+            if include_base:
+                assert sets.Set([(e.get_name(), getattr(e, 'status', None)) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('joesmith', None), ('wally', 'CGG'), ('jsmith', 'ABA')])
+            else:
+                assert sets.Set([(e.get_name(), e.status) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('wally', 'CGG'), ('jsmith', 'ABA')])
+            print "\n"
+
+        if not lazy_relation:
+            if polymorphic_fetch=='union':
+                self.assert_sql_count(testbase.db, go, 1)
+            else:
+                self.assert_sql_count(testbase.db, go, 5)
+                
         else:
-            assert sets.Set([(e.get_name(), e.status) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('wally', 'CGG'), ('jsmith', 'ABA')])
-        print "\n"
+            if polymorphic_fetch=='union':
+                self.assert_sql_count(testbase.db, go, 2)
+            else:
+                self.assert_sql_count(testbase.db, go, 6)
 
         # test selecting from the query, using the base mapped table (people) as the selection criterion.
         # in the case of the polymorphic Person query, the "people" selectable should be adapted to be "person_join"
index a2f9c4a5f001e10bec52970bd40f5b47c439d4bd..8ed4b806a59fa8de1d782b8ae6604f0f96129116 100644 (file)
@@ -313,6 +313,7 @@ class RelationTest4(ORMTest):
          in the union.  however, the primaryjoin condition is going to be against the base table, and its a many-to-one
          relationship (unlike the test in polymorph.py) so the column in the base table is explicit.  Can the ClauseAdapter
          figure out how to alias the primaryjoin to the polymorphic union ?"""
+         
         # class definitions
         class Person(object):
             def __init__(self, **kwargs):
@@ -397,9 +398,12 @@ class RelationTest4(ORMTest):
         session.clear()
         print "-----------------------------------------------------------------"
         # and now for the lightning round, eager !
-        car1 = session.query(Car).options(eagerload('employee')).get(car1.car_id)
-        assert str(car1.employee) == "Engineer E4, status X"
 
+        def go():
+            testcar = session.query(Car).options(eagerload('employee')).get(car1.car_id)
+            assert str(testcar.employee) == "Engineer E4, status X"
+        self.assert_sql_count(testbase.db, go, 1)
+        
         session.clear()
         s = session.query(Car)
         c = s.join("employee").filter(Person.name=="E4")[0]