]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added 'fetchmode' capability to deferred polymorphic loading.
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 23 May 2007 00:12:01 +0000 (00:12 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 23 May 2007 00:12:01 +0000 (00:12 +0000)
can load immediately via a second select or via deferred columns.
needs work to reduce complexity and possibly to eliminate unnecessary work
performed by ColumnLoader objects hitting lots of non-existent columns
- would like to add a post_exec() step to MapperProperty, but first need to devise some way
for MapperProperty instances to register themselves in the SelectContext as
requiring post_exec; otherwise we add huge method call overhead (and there is too
much already)
- fix to deferred loading so that the attributes loaded by "group" deferred loading
get proper CommittedState
- some refactoring to attributes to support setting attributes as committed

examples/polymorph/polymorph.py
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
test/orm/inheritance/polymorph.py
test/orm/inheritance/polymorph2.py
test/orm/mapper.py

index 6c4f0aae6a732f46fcd7e04109b6d1acbe5e0411..00214f919d3b7a2b407e296b38642d491701a071 100644 (file)
@@ -1,10 +1,15 @@
 from sqlalchemy import *
 import sets
 
-# this example illustrates a polymorphic load of two classes, where each class has a very 
+import logging
+logging.basicConfig()
+logging.getLogger('sqlalchemy.orm').setLevel(logging.DEBUG)
+logging.getLogger('sqlalchemy.engine').setLevel(logging.DEBUG)
+
+# this example illustrates a polymorphic load of two classes, where each class has a  
 # different set of properties
 
-metadata = BoundMetaData('sqlite://', echo='True')
+metadata = BoundMetaData('sqlite://')
 
 # a table to store companies
 companies = Table('companies', metadata, 
@@ -63,8 +68,8 @@ person_join = polymorphic_union(
         'person':people.select(people.c.type=='person'),
     }, None, 'pjoin')
 
-#person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=person_join.c.type, polymorphic_identity='person')
-person_mapper = mapper(Person, people, select_table=person_join,polymorphic_on=person_join.c.type, polymorphic_identity='person')
+#person_mapper = mapper(Person, people, select_table=person_join,polymorphic_on=person_join.c.type, polymorphic_identity='person')
+person_mapper = mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person')
 mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer')
 mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager')
 
index 7a73ecca5c4af9e8414bc12de4849504afa38921..a1fa1726e8df12726c221a2b0104c7709543f334 100644 (file)
@@ -165,6 +165,9 @@ def lazyload(name):
 
     return strategies.EagerLazyOption(name, lazy=True)
 
+def fetchmode(name, type):
+    return strategies.FetchModeOption(name, type)
+    
 def noload(name):
     """Return a ``MapperOption`` that will convert the property of the
     given name into a non-load.
index 9f8a04db850fd40d7df1858d37532e327b205890..edcd7756bb6079752ee3ae6a19d0251a80afd2c3 100644 (file)
@@ -17,7 +17,8 @@ class InstrumentedAttribute(object):
     """
 
     PASSIVE_NORESULT = object()
-
+    ATTR_WAS_SET = object()
+    
     def __init__(self, manager, key, uselist, callable_, typecallable, trackparent=False, extension=None, copy_function=None, compare_function=None, mutable_scalars=False, **kwargs):
         self.manager = manager
         self.key = key
@@ -52,6 +53,10 @@ class InstrumentedAttribute(object):
             return self
         return self.get(obj)
 
+    def get_instrument(cls, obj, key):
+        return getattr(obj.__class__, key)
+    get_instrument = classmethod(get_instrument)
+        
     def check_mutable_modified(self, obj):
         if self.mutable_scalars:
             h = self.get_history(obj, passive=True)
@@ -178,6 +183,28 @@ class InstrumentedAttribute(object):
             obj.__dict__[self.key] = None
             return None
 
+    def set_committed_value(self, obj, value):
+        """set an attribute value on the given instance and 'commit' it.
+        
+        this indicates that the given value is the "persisted" value,
+        and history will be logged only if a newly set value is not
+        equal to this value.
+        
+        this is typically used by deferred/lazy attribute loaders
+        to set object attributes after the initial load.
+        """
+        
+        state = obj._state
+        orig = state.get('original', None)
+        if self.uselist:
+            value = InstrumentedList(self, obj, value, init=False)
+        if orig is not None:
+            orig.commit_attribute(self, obj, value)
+        # remove per-instance callable, if any
+        state.pop(('callable', self), None)
+        obj.__dict__[self.key] = value
+        return value
+        
     def get(self, obj, passive=False, raiseerr=True):
         """Retrieve a value from the given object.
 
@@ -199,47 +226,27 @@ class InstrumentedAttribute(object):
                 trig()
                 return self.get(obj, passive=passive, raiseerr=raiseerr)
 
-            if self.uselist:
-                callable_ = self._get_callable(obj)
-                if callable_ is not None:
-                    if passive:
-                        return InstrumentedAttribute.PASSIVE_NORESULT
-                    self.logger.debug("Executing lazy callable on %s.%s" % (orm_util.instance_str(obj), self.key))
-                    values = callable_()
-                    l = InstrumentedList(self, obj, values, init=False)
-
-                    # if a callable was executed, then its part of the "committed state"
-                    # if any, so commit the newly loaded data
-                    orig = state.get('original', None)
-                    if orig is not None:
-                        orig.commit_attribute(self, obj, l)
-
+            callable_ = self._get_callable(obj)
+            if callable_ is not None:
+                if passive:
+                    return InstrumentedAttribute.PASSIVE_NORESULT
+                self.logger.debug("Executing lazy callable on %s.%s" % (orm_util.instance_str(obj), self.key))
+                value = callable_()
+                if value is not InstrumentedAttribute.ATTR_WAS_SET:
+                    return self.set_committed_value(obj, value)
                 else:
+                    return obj.__dict__[self.key]
+            else:
+                if self.uselist:
                     # note that we arent raising AttributeErrors, just creating a new
                     # blank list and setting it.
                     # this might be a good thing to be changeable by options.
-                    l = InstrumentedList(self, obj, self._blank_list(), init=False)
-                obj.__dict__[self.key] = l
-                return l
-            else:
-                callable_ = self._get_callable(obj)
-                if callable_ is not None:
-                    if passive:
-                        return InstrumentedAttribute.PASSIVE_NORESULT
-                    self.logger.debug("Executing lazy callable on %s.%s" % (orm_util.instance_str(obj), self.key))
-                    value = callable_()
-                    obj.__dict__[self.key] = value
-
-                    # if a callable was executed, then its part of the "committed state"
-                    # if any, so commit the newly loaded data
-                    orig = state.get('original', None)
-                    if orig is not None:
-                        orig.commit_attribute(self, obj)
-                    return value
+                    return self.set_committed_value(obj, self._blank_list())
                 else:
                     # note that we arent raising AttributeErrors, just returning None.
                     # this might be a good thing to be changeable by options.
-                    return None
+                    value = None
+                return value
 
     def set(self, event, obj, value):
         """Set a value on the given object.
index a9a26b57f9bd38217818e62a8e7e26ea7c07d1d4..c961b1b364edb22d12df53b3bc180f6e7deb2a88 100644 (file)
@@ -27,6 +27,11 @@ class MapperProperty(object):
 
         raise NotImplementedError()
 
+    def post_execute(self, selectcontext, instance):
+        """Called after all result rows have been received"""
+
+        raise NotImplementedError()
+        
     def cascade_iterator(self, type, object, recursive=None, halt_on=None):
         return []
 
index 6ae4fd647eff29931bb212ab22dff7455651abd2..b472abf0c78cd02e3da599746c45b799450ac8e4 100644 (file)
@@ -55,6 +55,7 @@ class Mapper(object):
                 polymorphic_on=None,
                 _polymorphic_map=None,
                 polymorphic_identity=None,
+                polymorphic_fetch=None,
                 concrete=False,
                 select_table=None,
                 allow_null_pks=False,
@@ -150,7 +151,14 @@ class Mapper(object):
           A value which will be stored in the Column denoted by
           polymorphic_on, corresponding to the *class identity* of
           this mapper.
-
+        
+        polymorphic_fetch
+          specifies how subclasses mapped through joined-table 
+          inheritance will be fetched.  options are 'union', 
+          'select', and 'deferred'.  if the select_table argument 
+          is present, defaults to 'union', otherwise defaults to
+          'select'.
+          
         concrete
           If True, indicates this mapper should use concrete table
           inheritance with its parent mapper.
@@ -227,6 +235,13 @@ class Mapper(object):
         # indicates this Mapper should be used to construct the object instance for that row.
         self.polymorphic_identity = polymorphic_identity
 
+        if polymorphic_fetch not in (None, 'union', 'select', 'deferred'):
+            raise exceptions.ArgumentError("Invalid option for 'polymorphic_fetch': '%s'" % polymorphic_fetch)
+        if polymorphic_fetch is None:
+            self.polymorphic_fetch = (self.select_table is None) and 'select' or 'union'
+        else:
+            self.polymorphic_fetch = polymorphic_fetch
+        
         # a dictionary of 'polymorphic identity' names, associating those names with
         # Mappers that will be used to construct object instances upon a select operation.
         if _polymorphic_map is None:
@@ -531,6 +546,12 @@ class Mapper(object):
             raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name))
 
         self.primary_key = primary_key
+
+        _get_clause = sql.and_()
+        for primary_key in self.primary_key:
+            _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type, unique=True))
+        self._get_clause = _get_clause
+
         
     def _compile_properties(self):
         """Inspect the properties dictionary sent to the Mapper's
@@ -1416,8 +1437,8 @@ class Mapper(object):
             if discriminator is not None:
                 mapper = self.polymorphic_map[discriminator]
                 if mapper is not self:
-                    if ('needsload', mapper) not in context.attributes:
-                        context.attributes[('needsload', mapper)] = (self, [t for t in mapper.tables if t not in self.tables])
+                    if ('polymorphic_fetch', mapper, self.polymorphic_fetch) not in context.attributes:
+                        context.attributes[('polymorphic_fetch', mapper, self.polymorphic_fetch)] = (self, [t for t in mapper.tables if t not in self.tables])
                     row = self.translate_row(mapper, row)
                     return mapper._instance(context, row, result=result, skip_polymorphic=True)
                     
@@ -1499,32 +1520,37 @@ class Mapper(object):
 
         return obj
 
+    def _deferred_inheritance_condition(self, needs_tables):
+        cond = self.inherit_condition.copy_container()
+
+        param_names = []
+        def visit_binary(binary):
+            leftcol = binary.left
+            rightcol = binary.right
+            if leftcol is None or rightcol is None:
+                return
+            if leftcol.table not in needs_tables:
+                binary.left = sql.bindparam(leftcol.name, None, type=binary.right.type, unique=True)
+                param_names.append(leftcol)
+            elif rightcol not in needs_tables:
+                binary.right = sql.bindparam(rightcol.name, None, type=binary.right.type, unique=True)
+                param_names.append(rightcol)
+        mapperutil.BinaryVisitor(visit_binary).traverse(cond)
+        return cond, param_names
+
     def _post_instance(self, context, instance):
-        (hosted_mapper, needs_tables) = context.attributes.get(('needsload', self), (None, None))
+        (hosted_mapper, needs_tables) = context.attributes.get(('polymorphic_fetch', self, 'select'), (None, None))
         if needs_tables is None or len(needs_tables) == 0:
             return
         
+        # TODO: this logic needs to be merged with the same logic in DeferredColumnLoader
         self.__log_debug("Post query loading instance " + mapperutil.instance_str(instance))
         if ('post_select', self) not in context.attributes:
-            cond = self.inherit_condition.copy_container()
-
-            param_names = []
-            def visit_binary(binary):
-                leftcol = binary.left
-                rightcol = binary.right
-                if leftcol is None or rightcol is None:
-                    return
-                if leftcol.table not in needs_tables:
-                    binary.left = sql.bindparam(leftcol.name, None, type=binary.right.type, unique=True)
-                    param_names.append(leftcol)
-                elif rightcol not in needs_tables:
-                    binary.right = sql.bindparam(rightcol.name, None, type=binary.right.type, unique=True)
-                    param_names.append(rightcol)
-            mapperutil.BinaryVisitor(visit_binary).traverse(cond)
-            statement = sql.select(needs_tables, cond)
+            cond, param_names = self._deferred_inheritance_condition(needs_tables)
+            statement = sql.select(needs_tables, cond, use_labels=True)
             context.attributes[('post_select', self)] = (statement, param_names)
             
-        (statement, binds) = context.attributes.get(('post_select', self))
+        (statement, binds) = context.attributes[('post_select', self)]
         
         identitykey = self.instance_key(instance)
         
index 6ed1a06d35d7803e611af556a5aa152e5b495939..9718f0adeb0acbdee6cd088ace14ca801befaf3d 100644 (file)
@@ -28,14 +28,10 @@ class Query(object):
         self.extension.append(self.mapper.extension)
         self.is_polymorphic = self.mapper is not self.select_mapper
         self._session = session
-        if not hasattr(self.mapper, '_get_clause'):
-            _get_clause = sql.and_()
-            for primary_key in self.primary_key_columns:
-                _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type, unique=True))
-            self.mapper._get_clause = _get_clause
             
         self._entities = []
-        self._get_clause = self.mapper._get_clause
+
+        self._get_clause = self.select_mapper._get_clause
 
         self._order_by = kwargs.pop('order_by', False)
         self._group_by = kwargs.pop('group_by', False)
index 1eeb77735b35d941aba7eee3069daa5ce36cf846..ed0b003db5b2d2493bc87f50b8d476ba8bbda943 100644 (file)
@@ -9,6 +9,7 @@
 from sqlalchemy import sql, schema, util, exceptions, sql_util, logging
 from sqlalchemy.orm import mapper, query
 from sqlalchemy.orm.interfaces import *
+from sqlalchemy.orm.attributes import InstrumentedAttribute
 from sqlalchemy.orm import session as sessionlib
 from sqlalchemy.orm import util as mapperutil
 import random
@@ -39,7 +40,9 @@ class ColumnLoader(LoaderStrategy):
             try:
                 instance.__dict__[self.key] = row[self.columns[0]]
             except KeyError:
-                pass
+                if self._should_log_debug:
+                    self.logger.debug("degrade to deferred column on %s" % mapperutil.attribute_str(instance, self.key))
+                self.parent_property._get_strategy(DeferredColumnLoader).process_row(selectcontext, instance, row, identitykey, isnew)
         
 ColumnLoader.logger = logging.class_logger(ColumnLoader)
 
@@ -66,58 +69,76 @@ class DeferredColumnLoader(LoaderStrategy):
     def process_row(self, selectcontext, instance, row, identitykey, isnew):
         if isnew:
             if not self.is_default or len(selectcontext.options):
-                sessionlib.attribute_manager.init_instance_attribute(instance, self.key, False, callable_=self.setup_loader(instance, selectcontext.options))
+                sessionlib.attribute_manager.init_instance_attribute(instance, self.key, False, callable_=self.setup_loader(instance, selectcontext))
             else:
                 sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
 
-    def setup_loader(self, instance, options=None):
-        if not mapper.has_mapper(instance):
+    def setup_loader(self, instance, context=None):
+        localparent = mapper.object_mapper(instance, raiseerror=False)
+        if localparent is None:
             return None
+            
+        prop = localparent.props[self.key]
+        if prop is not self.parent_property:
+            return prop._get_strategy(DeferredColumnLoader).setup_loader(instance)
+
+        if context is not None and ('polymorphic_fetch', localparent, 'deferred') in context.attributes:
+            (hosted_mapper, needs_tables) = context.attributes[('polymorphic_fetch', localparent, 'deferred')]
+            loadall = True
         else:
-            prop = mapper.object_mapper(instance).props[self.key]
-            if prop is not self.parent_property:
-                return prop._get_strategy(DeferredColumnLoader).setup_loader(instance)
-        def lazyload():
-            if self._should_log_debug:
-                self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), str(self.group)))
+            loadall = False
 
+        # clear context so it doesnt hang around attached to the instance
+        context = None
+        
+        def lazyload():
             if not mapper.has_identity(instance):
                 return None
 
-            try:
-                pk = self.parent.pks_by_table[self.columns[0].table]
-            except KeyError:
-                pk = self.columns[0].table.primary_key
-
-            clause = sql.and_()
-            for primary_key in pk:
-                attr = self.parent.get_attr_by_column(instance, primary_key)
-                if not attr:
-                    return None
-                clause.clauses.append(primary_key == attr)
+            if loadall:
+                # TODO: this logic needs to be merged with the same logic in Mapper
+                group = [p for p in localparent.props.values() if isinstance(p.strategy, ColumnLoader) and p.columns[0].table in needs_tables]
+            elif self.group is not None:
+                group = [p for p in localparent.props.values() if isinstance(p.strategy, DeferredColumnLoader) and p.group==self.group]
+            else:
+                group = None
+                
+            if self._should_log_debug:
+                self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), group and ','.join([p.key for p in group]) or 'None'))
 
             session = sessionlib.object_session(instance)
             if session is None:
                 raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
-
-            localparent = mapper.object_mapper(instance)
-            if self.group is not None:
-                groupcols = [p for p in localparent.props.values() if isinstance(p.strategy, DeferredColumnLoader) and p.group==self.group]
-                result = session.execute(localparent, sql.select([g.columns[0] for g in groupcols], clause, use_labels=True), None)
+                
+            if loadall:
+                # TODO: this logic needs to be merged with the same logic in Mapper
+                cond, param_names = localparent._deferred_inheritance_condition(needs_tables)
+                statement = sql.select(needs_tables, cond, use_labels=True)
+                params = {}
+                for c in param_names:
+                    params[c.name] = localparent.get_attr_by_column(instance, c)
+            else:
+                clause = localparent._get_clause
+                ident = instance._instance_key[1]
+                params = {}
+                for i, primary_key in enumerate(localparent.primary_key):
+                    params[primary_key._label] = ident[i]
+                if group is not None:
+                    statement = sql.select([p.columns[0] for p in group], clause, from_obj=[localparent.mapped_table], use_labels=True)
+                else:
+                    statement = sql.select([self.columns[0]], clause, from_obj=[localparent.mapped_table], use_labels=True)
+                    
+            if group is not None:
+                result = session.execute(localparent, statement, params)
                 try:
                     row = result.fetchone()
-                    for prop in groupcols:
-                        if prop is self:
-                            continue
-                        # set a scalar object instance directly on the object, 
-                        # bypassing SmartProperty event handlers.
-                        sessionlib.attribute_manager.init_instance_attribute(instance, prop.key, uselist=False)
-                        instance.__dict__[prop.key] = row[prop.columns[0]]
-                    return row[self.columns[0]]    
+                    for prop in group:
+                        InstrumentedAttribute.get_instrument(instance, prop.key).set_committed_value(instance, row[prop.columns[0]])
+                    return InstrumentedAttribute.ATTR_WAS_SET
                 finally:
                     result.close()
             else:
-                return session.scalar(localparent, sql.select([self.columns[0]], clause, use_labels=True),None)
+                return session.scalar(localparent, sql.select([self.columns[0]], clause, from_obj=[localparent.mapped_table], use_labels=True),params)
 
         return lazyload
                 
@@ -525,10 +546,10 @@ class EagerLoader(AbstractRelationLoader):
         """
         
         # check for a user-defined decorator in the SelectContext (which was set up by the contains_eager() option)
-        if selectcontext.attributes.has_key((EagerLoader, self.parent_property)):
+        if selectcontext.attributes.has_key(("eager_row_processor", self.parent_property)):
             # custom row decoration function, placed in the selectcontext by the 
             # contains_eager() mapper option
-            decorator = selectcontext.attributes[(EagerLoader, self.parent_property)]
+            decorator = selectcontext.attributes[("eager_row_processor", self.parent_property)]
             if decorator is None:
                 decorator = lambda row: row
         else:
@@ -589,7 +610,7 @@ class EagerLoader(AbstractRelationLoader):
                     self.logger.debug("eagerload scalar instance on %s" % mapperutil.attribute_str(instance, self.key))
                 if isnew:
                     # set a scalar object instance directly on the parent object, 
-                    # bypassing SmartProperty event handlers.
+                    # bypassing InstrumentedAttribute event handlers.
                     instance.__dict__[self.key] = self.mapper._instance(selectcontext, decorated_row, None)
                 else:
                     # call _instance on the row, even though the object has been created,
@@ -599,8 +620,9 @@ class EagerLoader(AbstractRelationLoader):
                 if isnew:
                     if self._should_log_debug:
                         self.logger.debug("initialize UniqueAppender on %s" % mapperutil.attribute_str(instance, self.key))
-                    # call the SmartProperty's initialize() method to create a new, blank list
-                    l = getattr(instance.__class__, self.key).initialize(instance)
+
+                    # call the InstrumentedAttribute's initialize() method to create a new, blank list
+                    l = InstrumentedAttribute.get_instrument(instance, self.key).initialize(instance)
                 
                     # create an appender object which will add set-like semantics to the list
                     appender = util.UniqueAppender(l.data)
@@ -639,6 +661,16 @@ class EagerLazyOption(StrategizedOption):
 
 EagerLazyOption.logger = logging.class_logger(EagerLazyOption)
 
+class FetchModeOption(PropertyOption):
+    def __init__(self, key, type):
+        super(FetchModeOption, self).__init__(key)
+        if type not in ('join', 'select'):
+            raise exceptions.ArgumentError("Fetchmode must be one of 'join' or 'select'")
+        self.type = type
+        
+    def process_selection_property(self, context, property):
+        context.attributes[('fetchmode', property)] = self.type
+        
 class RowDecorateOption(PropertyOption):
     def __init__(self, key, decorator=None, alias=None):
         super(RowDecorateOption, self).__init__(key)
@@ -655,7 +687,7 @@ class RowDecorateOption(PropertyOption):
                     d[c] = row[self.alias.corresponding_column(c)]
                 return d
             self.decorator = decorate
-        context.attributes[(EagerLoader, property)] = self.decorator
+        context.attributes[("eager_row_processor", property)] = self.decorator
 
 RowDecorateOption.logger = logging.class_logger(RowDecorateOption)
         
index 0ef9984ae55b55a3fb0534dd6e92d13efd683897..1ff608d24e66939bd8499dde9bbb4c27acd677f2 100644 (file)
@@ -195,7 +195,7 @@ class RelationToSubclassTest(PolymorphTest):
 class RoundTripTest(PolymorphTest):
     pass
           
-def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_colprop=False, use_literal_join=False, use_union=False):
+def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_colprop=False, use_literal_join=False, polymorphic_fetch=None):
     """generates a round trip test.
     
     include_base - whether or not to include the base 'person' type in the union.
@@ -205,7 +205,7 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
     """
     def test_roundtrip(self):
         # create a union that represents both types of joins.  
-        if not use_union:
+        if not polymorphic_fetch == 'union':
             person_join = None
         elif include_base:
             person_join = polymorphic_union(
@@ -222,9 +222,9 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
                 }, None, 'pjoin')
 
         if redefine_colprop:
-            person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=people.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name})
+            person_mapper = mapper(Person, people, select_table=person_join, polymorphic_fetch=polymorphic_fetch, polymorphic_on=people.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name})
         else:
-            person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=people.c.type, polymorphic_identity='person')
+            person_mapper = mapper(Person, people, select_table=person_join, polymorphic_fetch=polymorphic_fetch, polymorphic_on=people.c.type, polymorphic_identity='person')
         
         mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer')
         mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager')
@@ -304,7 +304,7 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
         (lazy_relation and "lazy" or "eager"),
         (include_base and "_inclbase" or ""),
         (redefine_colprop and "_redefcol" or ""),
-        (not use_union and "_nounion" or (use_literal_join and "_litjoin" or ""))
+        (polymorphic_fetch != 'union' and '_' + polymorphic_fetch or (use_literal_join and "_litjoin" or ""))
     )
     setattr(RoundTripTest, test_roundtrip.__name__, test_roundtrip)
 
@@ -312,8 +312,8 @@ for include_base in [True, False]:
     for lazy_relation in [True, False]:
         for redefine_colprop in [True, False]:
             for use_literal_join in [True, False]:
-                for use_union in [True, False]:
-                    generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join, use_union)
+                for polymorphic_fetch in ['union', 'select', 'deferred']:
+                    generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join, polymorphic_fetch)
                 
 if __name__ == "__main__":    
     testbase.main()
index fb704a495cd7ab641a78aabcdbc4289e268bf70f..bd95d4ff23fc1548f95f7aad8881bc7c1d3d074a 100644 (file)
@@ -374,6 +374,7 @@ class RelationTest4(testbase.ORMTest):
 
         # All print should output the same person (engineer E4)
         assert str(engineer4) == "Engineer E4, status X"
+        print str(usingGet)
         assert str(usingGet) == "Engineer E4, status X"
         assert str(usingProperty) == "Engineer E4, status X"
 
index 889a7c925972c896ae04a8224cf2e87aac9a111c..4cb721eeb52b68de53d53db20502391ab2934d74 100644 (file)
@@ -839,7 +839,8 @@ class DeferredTest(MapperSuperTest):
             'description':deferred(orders.c.description, group='primary'),
             'opened':deferred(orders.c.isopen, group='primary')
         })
-        q = create_session().query(m)
+        sess = create_session()
+        q = sess.query(m)
         def go():
             l = q.select()
             o2 = l[2]
@@ -853,6 +854,37 @@ class DeferredTest(MapperSuperTest):
             ("SELECT orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3})
         ])
         
+        o2 = q.select()[2]
+#        assert o2.opened == 1
+        assert o2.description == 'order 3'
+        assert o2 not in sess.dirty
+        o2.description = 'order 3'
+        def go():
+            sess.flush()
+        self.assert_sql_count(db, go, 0)
+    
+    def testcommitsstate(self):
+        """test that when deferred elements are loaded via a group, they get the proper CommittedState
+        and dont result in changes being committed"""
+        
+        m = mapper(Order, orders, properties = {
+            'userident':deferred(orders.c.user_id, group='primary'),
+            'description':deferred(orders.c.description, group='primary'),
+            'opened':deferred(orders.c.isopen, group='primary')
+        })
+        sess = create_session()
+        q = sess.query(m)
+        o2 = q.select()[2]
+        # this will load the group of attributes
+        assert o2.description == 'order 3'
+        assert o2 not in sess.dirty
+        # this will mark it as 'dirty', but nothing actually changed
+        o2.description = 'order 3'
+        def go():
+            # therefore the flush() shouldnt actually issue any SQL
+            sess.flush()
+        self.assert_sql_count(db, go, 0)
+            
     def testoptions(self):
         """tests using options on a mapper to create deferred and undeferred columns"""
         m = mapper(Order, orders)