]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Query.with_polymorphic() now accepts a third
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 3 Dec 2008 21:27:04 +0000 (21:27 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 3 Dec 2008 21:27:04 +0000 (21:27 +0000)
argument "discriminator" which will replace
the value of mapper.polymorphic_on for that
query.  Mappers themselves no longer require
polymorphic_on to be set, even if the mapper
has a polymorphic_identity.   When not set,
the mapper will load non-polymorphically
by default. Together, these two features allow
a non-polymorphic concrete inheritance setup
to use polymorphic loading on a per-query basis,
since concrete setups are prone to many
issues when used polymorphically in all cases.

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
test/orm/inheritance/basic.py
test/orm/inheritance/concrete.py

diff --git a/CHANGES b/CHANGES
index f1245c7bfafc64293a2d2d7e05f041bf1d50cace..7ae9e7b60c5fe1a2ff95c8fea94c805d594f6e19 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -6,6 +6,21 @@ CHANGES
 =======
 0.5.0rc5
 ========
+- new features
+- orm
+    - Query.with_polymorphic() now accepts a third
+      argument "discriminator" which will replace
+      the value of mapper.polymorphic_on for that
+      query.  Mappers themselves no longer require
+      polymorphic_on to be set, even if the mapper
+      has a polymorphic_identity.   When not set,
+      the mapper will load non-polymorphically 
+      by default. Together, these two features allow 
+      a non-polymorphic concrete inheritance setup 
+      to use polymorphic loading on a per-query basis,
+      since concrete setups are prone to many
+      issues when used polymorphically in all cases.
+      
 - bugfixes, behavioral changes
 - orm
     - Query.select_from(), from_statement() ensure
index 48c2f9e27f590b8e0a198689d0feac7419972cf8..9c62cadd9a7c3d271932b5c567454a1c0b1689ab 100644 (file)
@@ -165,7 +165,7 @@ class Mapper(object):
             self.local_table = self.local_table.alias()
 
         if self.with_polymorphic and isinstance(self.with_polymorphic[1], expression._SelectBaseMixin):
-            self.with_polymorphic[1] = self.with_polymorphic[1].alias()
+            self.with_polymorphic = (self.with_polymorphic[0], self.with_polymorphic[1].alias())
 
         # our 'polymorphic identity', a string name that when located in a result set row
         # indicates this Mapper should be used to construct the object instance for that row.
@@ -270,20 +270,11 @@ class Mapper(object):
                         if mapper.polymorphic_on:
                             self.polymorphic_on = self.mapped_table.corresponding_column(mapper.polymorphic_on)
                             break
-                    else:
-                        # TODO: this exception not covered
-                        raise sa_exc.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', "
-                                    "but no mapper in it's hierarchy specifies "
-                                    "the 'polymorphic_on' column argument" % (self, self.polymorphic_identity))
         else:
             self._all_tables = set()
             self.base_mapper = self
             self.mapped_table = self.local_table
             if self.polymorphic_identity:
-                if self.polymorphic_on is None:
-                    raise sa_exc.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but "
-                                "no mapper in it's hierarchy specifies the "
-                                "'polymorphic_on' column argument" % (self, self.polymorphic_identity))
                 self.polymorphic_map[self.polymorphic_identity] = self
             self._identity_class = self.class_
 
@@ -1489,7 +1480,7 @@ class Mapper(object):
 
     # result set conversion
 
-    def _instance_processor(self, context, path, adapter, polymorphic_from=None, extension=None, only_load_props=None, refresh_state=None):
+    def _instance_processor(self, context, path, adapter, polymorphic_from=None, extension=None, only_load_props=None, refresh_state=None, polymorphic_discriminator=None):
         """Produce a mapper level row processor callable which processes rows into mapped instances."""
         
         pk_cols = self.primary_key
@@ -1497,7 +1488,7 @@ class Mapper(object):
         if polymorphic_from or refresh_state:
             polymorphic_on = None
         else:
-            polymorphic_on = self.polymorphic_on
+            polymorphic_on = polymorphic_discriminator or self.polymorphic_on
             polymorphic_instances = util.PopulateDict(self._configure_subclass_mapper(context, path, adapter))
 
         version_id_col = self.version_id_col
index 4bff81d6791980b41e534760fabfc80f0ffd9419..88357f34cb9c03259df05d16691d2ab40db9165f 100644 (file)
@@ -358,7 +358,7 @@ class Query(object):
         self._current_path = path
 
     @_generative(__no_clauseelement_condition)
-    def with_polymorphic(self, cls_or_mappers, selectable=None):
+    def with_polymorphic(self, cls_or_mappers, selectable=None, discriminator=None):
         """Load columns for descendant mappers of this Query's mapper.
 
         Using this method will ensure that each descendant mapper's
@@ -367,12 +367,12 @@ class Query(object):
         instances will also have those columns already loaded so that
         no "post fetch" of those columns will be required.
 
-        ``cls_or_mappers`` is a single class or mapper, or list of class/mappers,
+        :param cls_or_mappers: - a single class or mapper, or list of class/mappers,
         which inherit from this Query's mapper.  Alternatively, it
         may also be the string ``'*'``, in which case all descending
         mappers will be added to the FROM clause.
 
-        ``selectable`` is a table or select() statement that will
+        :param selectable: - a table or select() statement that will
         be used in place of the generated FROM clause.  This argument
         is required if any of the desired mappers use concrete table
         inheritance, since SQLAlchemy currently cannot generate UNIONs
@@ -382,9 +382,15 @@ class Query(object):
         will result in their table being appended directly to the FROM
         clause which will usually lead to incorrect results.
 
+        :param discriminator: - a column to be used as the "discriminator"
+        column for the given selectable.  If not given, the polymorphic_on
+        attribute of the mapper will be used, if any.   This is useful
+        for mappers that don't have polymorphic loading behavior by default,
+        such as concrete table mappers.
+        
         """
         entity = self._generate_mapper_zero()
-        entity.set_with_polymorphic(self, cls_or_mappers, selectable=selectable)
+        entity.set_with_polymorphic(self, cls_or_mappers, selectable=selectable, discriminator=discriminator)
 
     @_generative()
     def yield_per(self, count):
@@ -1654,6 +1660,7 @@ class _MapperEntity(_QueryEntity):
         self.adapter = adapter
         self.selectable  = from_obj
         self._with_polymorphic = with_polymorphic
+        self._polymorphic_discriminator = None
         self.is_aliased_class = is_aliased_class
         if is_aliased_class:
             self.path_entity = self.entity = self.entity_zero = entity
@@ -1661,13 +1668,14 @@ class _MapperEntity(_QueryEntity):
             self.path_entity = mapper.base_mapper
             self.entity = self.entity_zero = mapper
 
-    def set_with_polymorphic(self, query, cls_or_mappers, selectable):
+    def set_with_polymorphic(self, query, cls_or_mappers, selectable, discriminator):
         if cls_or_mappers is None:
             query._reset_polymorphic_adapter(self.mapper)
             return
 
         mappers, from_obj = self.mapper._with_polymorphic_args(cls_or_mappers, selectable)
         self._with_polymorphic = mappers
+        self._polymorphic_discriminator = discriminator
 
         # TODO: do the wrapped thing here too so that with_polymorphic() can be
         # applied to aliases
@@ -1718,10 +1726,12 @@ class _MapperEntity(_QueryEntity):
 
         if self.primary_entity:
             _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter,
-                extension=self.extension, only_load_props=query._only_load_props, refresh_state=context.refresh_state
+                extension=self.extension, only_load_props=query._only_load_props, refresh_state=context.refresh_state,
+                polymorphic_discriminator=self._polymorphic_discriminator
             )
         else:
-            _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter)
+            _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter,
+                             polymorphic_discriminator=self._polymorphic_discriminator)
 
         if custom_rows:
             def main(context, row, result):
@@ -1759,7 +1769,14 @@ class _MapperEntity(_QueryEntity):
                 only_load_props=query._only_load_props,
                 column_collection=context.primary_columns
             )
-
+        
+        if self._polymorphic_discriminator:
+            if adapter:
+                pd = adapter.columns[self._polymorphic_discriminator]
+            else:
+                pd = self._polymorphic_discriminator
+            context.primary_columns.append(pd)
+            
     def __str__(self):
         return str(self.mapper)
 
index b7759aaeb3712a15b158441051d9c8f291139fd9..8e51105c9e0ec585c27d8efac1625a99399ffd20 100644 (file)
@@ -274,50 +274,6 @@ class GetTest(ORMTest):
     test_get_polymorphic = create_test(True, 'test_get_polymorphic')
     test_get_nonpolymorphic = create_test(False, 'test_get_nonpolymorphic')
 
-class ConstructionTest(ORMTest):
-    def define_tables(self, metadata):
-        global content_type, content, product
-        content_type = Table('content_type', metadata,
-            Column('id', Integer, primary_key=True)
-            )
-        content = Table('content', metadata,
-            Column('id', Integer, primary_key=True),
-            Column('content_type_id', Integer, ForeignKey('content_type.id')),
-            Column('type', String(30))
-            )
-        product = Table('product', metadata,
-            Column('id', Integer, ForeignKey('content.id'), primary_key=True)
-        )
-
-    def testbasic(self):
-        class ContentType(object): pass
-        class Content(object): pass
-        class Product(Content): pass
-
-        content_types = mapper(ContentType, content_type)
-        try:
-            contents = mapper(Content, content, properties={
-                'content_type':relation(content_types)
-            }, polymorphic_identity='contents')
-            assert False
-        except sa_exc.ArgumentError, e:
-            assert str(e) == "Mapper 'Mapper|Content|content' specifies a polymorphic_identity of 'contents', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument"
-
-    def testbackref(self):
-        """tests adding a property to the superclass mapper"""
-        class ContentType(object): pass
-        class Content(object): pass
-        class Product(Content): pass
-
-        contents = mapper(Content, content, polymorphic_on=content.c.type, polymorphic_identity='content')
-        products = mapper(Product, product, inherits=contents, polymorphic_identity='product')
-        content_types = mapper(ContentType, content_type, properties={
-            'content':relation(contents, backref='contenttype')
-        })
-        p = Product()
-        p.contenttype = ContentType()
-        # TODO: assertion ??
-
 class EagerLazyTest(ORMTest):
     """tests eager load/lazy load of child items off inheritance mappers, tests that
     LazyLoader constructs the right query condition."""
index e6277f3e919978138a5fd2d86863673d42acabc3..c523232c945c2ced9ff55748ded9b0582c5a8db0 100644 (file)
@@ -4,6 +4,40 @@ from sqlalchemy.orm import *
 from sqlalchemy.orm import exc as orm_exc
 from testlib import *
 from sqlalchemy.orm import attributes
+from testlib.testing import eq_
+
+class Employee(object):
+    def __init__(self, name):
+        self.name = name
+    def __repr__(self):
+        return self.__class__.__name__ + " " + self.name
+
+class Manager(Employee):
+    def __init__(self, name, manager_data):
+        self.name = name
+        self.manager_data = manager_data
+    def __repr__(self):
+        return self.__class__.__name__ + " " + self.name + " " +  self.manager_data
+
+class Engineer(Employee):
+    def __init__(self, name, engineer_info):
+        self.name = name
+        self.engineer_info = engineer_info
+    def __repr__(self):
+        return self.__class__.__name__ + " " + self.name + " " +  self.engineer_info
+
+class Hacker(Engineer):
+    def __init__(self, name, nickname, engineer_info):
+        self.name = name
+        self.nickname = nickname
+        self.engineer_info = engineer_info
+    def __repr__(self):
+        return self.__class__.__name__ + " " + self.name + " '" + \
+               self.nickname + "' " +  self.engineer_info
+
+class Company(object):
+   pass
+
 
 class ConcreteTest(ORMTest):
     def define_tables(self, metadata):
@@ -42,26 +76,6 @@ class ConcreteTest(ORMTest):
         )
 
     def test_basic(self):
-        class Employee(object):
-            def __init__(self, name):
-                self.name = name
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name
-
-        class Manager(Employee):
-            def __init__(self, name, manager_data):
-                self.name = name
-                self.manager_data = manager_data
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name + " " +  self.manager_data
-
-        class Engineer(Employee):
-            def __init__(self, name, engineer_info):
-                self.name = name
-                self.engineer_info = engineer_info
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name + " " +  self.engineer_info
-
         pjoin = polymorphic_union({
             'manager':managers_table,
             'engineer':engineers_table
@@ -77,45 +91,15 @@ class ConcreteTest(ORMTest):
         session.flush()
         session.clear()
 
-        print set([repr(x) for x in session.query(Employee).all()])
-        assert set([repr(x) for x in session.query(Employee).all()]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"])
-        assert set([repr(x) for x in session.query(Manager).all()]) == set(["Manager Tom knows how to manage things"])
-        assert set([repr(x) for x in session.query(Engineer).all()]) == set(["Engineer Kurt knows how to hack"])
+        assert set([repr(x) for x in session.query(Employee)]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"])
+        assert set([repr(x) for x in session.query(Manager)]) == set(["Manager Tom knows how to manage things"])
+        assert set([repr(x) for x in session.query(Engineer)]) == set(["Engineer Kurt knows how to hack"])
 
         manager = session.query(Manager).one()
         session.expire(manager, ['manager_data'])
         self.assertEquals(manager.manager_data, "knows how to manage things")
 
     def test_multi_level_no_base(self):
-        class Employee(object):
-            def __init__(self, name):
-                self.name = name
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name
-
-        class Manager(Employee):
-            def __init__(self, name, manager_data):
-                self.name = name
-                self.manager_data = manager_data
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name + " " +  self.manager_data
-
-        class Engineer(Employee):
-            def __init__(self, name, engineer_info):
-                self.name = name
-                self.engineer_info = engineer_info
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name + " " +  self.engineer_info
-
-        class Hacker(Engineer):
-            def __init__(self, name, nickname, engineer_info):
-                self.name = name
-                self.nickname = nickname
-                self.engineer_info = engineer_info
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name + " '" + \
-                       self.nickname + "' " +  self.engineer_info
-
         pjoin = polymorphic_union({
             'manager': managers_table,
             'engineer': engineers_table,
@@ -166,35 +150,6 @@ class ConcreteTest(ORMTest):
         assert set([repr(x) for x in session.query(Hacker).all()]) == set(["Hacker Kurt 'Badass' knows how to hack"])
 
     def test_multi_level_with_base(self):
-        class Employee(object):
-            def __init__(self, name):
-                self.name = name
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name
-
-        class Manager(Employee):
-            def __init__(self, name, manager_data):
-                self.name = name
-                self.manager_data = manager_data
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name + " " +  self.manager_data
-
-        class Engineer(Employee):
-            def __init__(self, name, engineer_info):
-                self.name = name
-                self.engineer_info = engineer_info
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name + " " +  self.engineer_info
-
-        class Hacker(Engineer):
-            def __init__(self, name, nickname, engineer_info):
-                self.name = name
-                self.nickname = nickname
-                self.engineer_info = engineer_info
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name + " '" + \
-                       self.nickname + "' " +  self.engineer_info
-
         pjoin = polymorphic_union({
             'employee':employees_table,
             'manager': managers_table,
@@ -227,12 +182,6 @@ class ConcreteTest(ORMTest):
         session.add_all((tom, jerry, hacker))
         session.flush()
 
-        # ensure "readonly" on save logic didn't pollute the expired_attributes
-        # collection
-        assert 'nickname' not in attributes.instance_state(jerry).expired_attributes
-        assert 'name' not in attributes.instance_state(jerry).expired_attributes
-        assert 'name' not in attributes.instance_state(hacker).expired_attributes
-        assert 'nickname' not in attributes.instance_state(hacker).expired_attributes
         def go():
             self.assertEquals(jerry.name, "Jerry")
             self.assertEquals(hacker.nickname, "Badass")
@@ -245,35 +194,85 @@ class ConcreteTest(ORMTest):
         # in the statement which is only against Employee's "pjoin"
         assert len(testing.db.execute(session.query(Employee).with_labels().statement).fetchall()) == 3
         
-        assert set([repr(x) for x in session.query(Employee).all()]) == set(["Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"])
-        assert set([repr(x) for x in session.query(Manager).all()]) == set(["Manager Tom knows how to manage things"])
-        assert set([repr(x) for x in session.query(Engineer).all()]) == set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"])
-        assert set([repr(x) for x in session.query(Hacker).all()]) == set(["Hacker Kurt 'Badass' knows how to hack"])
+        assert set([repr(x) for x in session.query(Employee)]) == set(["Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"])
+        assert set([repr(x) for x in session.query(Manager)]) == set(["Manager Tom knows how to manage things"])
+        assert set([repr(x) for x in session.query(Engineer)]) == set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"])
+        assert set([repr(x) for x in session.query(Hacker)]) == set(["Hacker Kurt 'Badass' knows how to hack"])
 
-    def test_relation(self):
-        class Employee(object):
-            def __init__(self, name):
-                self.name = name
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name
-
-        class Manager(Employee):
-            def __init__(self, name, manager_data):
-                self.name = name
-                self.manager_data = manager_data
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name + " " +  self.manager_data
-
-        class Engineer(Employee):
-            def __init__(self, name, engineer_info):
-                self.name = name
-                self.engineer_info = engineer_info
-            def __repr__(self):
-                return self.__class__.__name__ + " " + self.name + " " +  self.engineer_info
-
-        class Company(object):
-            pass
+    
+    def test_without_default_polymorphic(self):
+        pjoin = polymorphic_union({
+            'employee':employees_table,
+            'manager': managers_table,
+            'engineer': engineers_table,
+            'hacker': hackers_table
+        }, 'type', 'pjoin')
+
+        pjoin2 = polymorphic_union({
+            'engineer': engineers_table,
+            'hacker': hackers_table
+        }, 'type', 'pjoin2')
 
+        employee_mapper = mapper(Employee, employees_table, 
+                                polymorphic_identity='employee')
+        manager_mapper = mapper(Manager, managers_table, 
+                                inherits=employee_mapper, concrete=True, 
+                                polymorphic_identity='manager')
+        engineer_mapper = mapper(Engineer, engineers_table, 
+                                 inherits=employee_mapper, concrete=True,
+                                 polymorphic_identity='engineer')
+        hacker_mapper = mapper(Hacker, hackers_table, 
+                               inherits=engineer_mapper,
+                               concrete=True, polymorphic_identity='hacker')
+
+        session = create_session()
+        jdoe = Employee('Jdoe')
+        tom = Manager('Tom', 'knows how to manage things')
+        jerry = Engineer('Jerry', 'knows how to program')
+        hacker = Hacker('Kurt', 'Badass', 'knows how to hack')
+        session.add_all((jdoe, tom, jerry, hacker))
+        session.flush()
+
+        eq_(
+            len(testing.db.execute(session.query(Employee).with_polymorphic('*', pjoin, pjoin.c.type).with_labels().statement).fetchall()),
+            4
+        )
+        
+        eq_(
+            session.query(Employee).get(jdoe.employee_id), jdoe
+        )
+        eq_(
+            session.query(Engineer).get(jerry.employee_id), jerry
+        )
+        eq_(
+            set([repr(x) for x in session.query(Employee).with_polymorphic('*', pjoin, pjoin.c.type)]),
+            set(["Employee Jdoe", "Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"])
+        )
+        eq_(
+            set([repr(x) for x in session.query(Manager)]),
+            set(["Manager Tom knows how to manage things"])
+        )
+        eq_(
+            set([repr(x) for x in session.query(Engineer).with_polymorphic('*', pjoin2, pjoin2.c.type)]),
+            set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"])
+        )
+        eq_(
+            set([repr(x) for x in session.query(Hacker)]),
+            set(["Hacker Kurt 'Badass' knows how to hack"])
+        )
+        # test adaption of the column by wrapping the query in a subquery
+        eq_(
+            len(testing.db.execute(
+                session.query(Engineer).with_polymorphic('*', pjoin2, pjoin2.c.type).from_self().statement
+            ).fetchall()),
+            2
+        )
+        eq_(
+            set([repr(x) for x in session.query(Engineer).with_polymorphic('*', pjoin2, pjoin2.c.type).from_self()]),
+            set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"])
+        )
+        
+    def test_relation(self):
         pjoin = polymorphic_union({
             'manager':managers_table,
             'engineer':engineers_table
@@ -342,11 +341,11 @@ class ColKeysTest(ORMTest):
                                 concrete=True, polymorphic_identity='refugee')
 
         sess = create_session()
-        assert sess.query(Refugee).get(1).name == "refugee1"
-        assert sess.query(Refugee).get(2).name == "refugee2"
+        eq_(sess.query(Refugee).get(1).name, "refugee1")
+        eq_(sess.query(Refugee).get(2).name, "refugee2")
 
-        assert sess.query(Office).get(1).name == "office1"
-        assert sess.query(Office).get(2).name == "office2"
+        eq_(sess.query(Office).get(1).name, "office1")
+        eq_(sess.query(Office).get(2).name, "office2")
 
 if __name__ == '__main__':
     testenv.main()