]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- reworked all lazy/deferred/expired callables to be
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 21 Dec 2007 06:57:20 +0000 (06:57 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 21 Dec 2007 06:57:20 +0000 (06:57 +0000)
serializable class instances, added pickling tests
- cleaned up "deferred" polymorphic system so that the
mapper handles it entirely
- columns which are missing from a Query's select statement
now get automatically deferred during load.

13 files changed:
CHANGES
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/expression.py
test/orm/alltests.py
test/orm/attributes.py
test/orm/expire.py
test/orm/mapper.py
test/orm/pickled.py [new file with mode: 0644]

diff --git a/CHANGES b/CHANGES
index d8075c921d7a674519bf1e290de59cb425c96aa7..68a73eb8e9da0d4ef66ce1e59709970ef388463d 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -99,7 +99,14 @@ CHANGES
      have each method called only once per operation, use the same 
      instance of the extension for both mappers.
      [ticket:490]
+
+   - columns which are missing from a Query's select statement
+     now get automatically deferred during load.
      
+   - improved support for pickling of mapped entities.  Per-instance
+     lazy/deferred/expired callables are now serializable, so they
+     can be pickled and unpickled along with the instance's _state.
+       
    - new synonym() behavior: an attribute will be placed on the mapped
      class, if one does not exist already, in all cases. if a property
      already exists on the class, the synonym will decorate the property
index 089522673c69126fd1bfc2456c50687873ac3c44..135269906bd23969ce48758ac6e59784563de077 100644 (file)
@@ -255,12 +255,13 @@ class AttributeImpl(object):
 class ScalarAttributeImpl(AttributeImpl):
     """represents a scalar value-holding InstrumentedAttribute."""
 
-    accepts_global_callable = True
+    accepts_scalar_loader = True
     
     def delete(self, state):
         if self.key not in state.committed_state:
             state.committed_state[self.key] = state.dict.get(self.key, NO_VALUE)
 
+        # TODO: catch key errors, convert to attributeerror?
         del state.dict[self.key]
         state.modified=True
 
@@ -327,7 +328,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
     Adds events to delete/set operations.
     """
 
-    accepts_global_callable = False
+    accepts_scalar_loader = False
 
     def __init__(self, class_, key, callable_, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
         super(ScalarObjectAttributeImpl, self).__init__(class_, key,
@@ -338,6 +339,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
         
     def delete(self, state):
         old = self.get(state)
+        # TODO: catch key errors, convert to attributeerror?
         del state.dict[self.key]
         self.fire_remove_event(state, old, self)
 
@@ -404,7 +406,7 @@ class CollectionAttributeImpl(AttributeImpl):
     CollectionAdapter, a "view" onto that object that presents consistent
     bag semantics to the orm layer independent of the user data implementation.
     """
-    accepts_global_callable = False
+    accepts_scalar_loader = False
     
     def __init__(self, class_, key, callable_, typecallable=None, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
         super(CollectionAttributeImpl, self).__init__(class_, 
@@ -479,6 +481,7 @@ class CollectionAttributeImpl(AttributeImpl):
 
         collection = self.get_collection(state)
         collection.clear_with_event()
+        # TODO: catch key errors, convert to attributeerror?
         del state.dict[self.key]
 
     def initialize(self, state):
@@ -648,7 +651,7 @@ class ClassState(object):
         self.mappers = {}
         self.attrs = {}
         self.has_mutable_scalars = False
-        
+
 class InstanceState(object):
     """tracks state information at the instance level."""
 
@@ -658,7 +661,6 @@ class InstanceState(object):
         self.dict = obj.__dict__
         self.committed_state = {}
         self.modified = False
-        self.trigger = None
         self.callables = {}
         self.parents = {}
         self.pending = {}
@@ -735,7 +737,7 @@ class InstanceState(object):
             return None
             
     def __getstate__(self):
-        return {'committed_state':self.committed_state, 'pending':self.pending, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj()}
+        return {'committed_state':self.committed_state, 'pending':self.pending, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj(), 'expired_attributes':getattr(self, 'expired_attributes', None), 'callables':self.callables}
     
     def __setstate__(self, state):
         self.committed_state = state['committed_state']
@@ -745,43 +747,62 @@ class InstanceState(object):
         self.obj = weakref.ref(state['instance'])
         self.class_ = self.obj().__class__
         self.dict = self.obj().__dict__
-        self.callables = {}
-        self.trigger = None
-    
+        self.callables = state['callables']
+        self.runid = None
+        self.appenders = {}
+        if state['expired_attributes'] is not None:
+            self.expire_attributes(state['expired_attributes'])
+
     def initialize(self, key):
         getattr(self.class_, key).impl.initialize(self)
         
     def set_callable(self, key, callable_):
         self.dict.pop(key, None)
         self.callables[key] = callable_
-    
-    def __fire_trigger(self):
+
+    def __call__(self):
+        """__call__ allows the InstanceState to act as a deferred 
+        callable for loading expired attributes, which is also
+        serializable.
+        """
         instance = self.obj()
-        self.trigger(instance, [k for k in self.expired_attributes if k not in self.dict])
+        self.class_._class_state.deferred_scalar_loader(instance, [k for k in self.expired_attributes if k not in self.committed_state])
         for k in self.expired_attributes:
             self.callables.pop(k, None)
         self.expired_attributes.clear()
         return ATTR_WAS_SET
     
+    def unmodified(self):
+        """a set of keys which have no uncommitted changes"""
+
+        return util.Set([
+            attr.impl.key for attr in _managed_attributes(self.class_) if
+            attr.impl.key not in self.committed_state
+            and (not hasattr(attr.impl, 'commit_to_state') or not attr.impl.check_mutable_modified(self))
+        ])
+    unmodified = property(unmodified)
+    
     def expire_attributes(self, attribute_names):
         if not hasattr(self, 'expired_attributes'):
             self.expired_attributes = util.Set()
+            
         if attribute_names is None:
             for attr in _managed_attributes(self.class_):
                 self.dict.pop(attr.impl.key, None)
-                self.callables[attr.impl.key] = self.__fire_trigger
-                self.expired_attributes.add(attr.impl.key)
+
+                if attr.impl.accepts_scalar_loader:
+                    self.callables[attr.impl.key] = self
+                    self.expired_attributes.add(attr.impl.key)
+
             self.committed_state = {}
         else:
             for key in attribute_names:
                 self.dict.pop(key, None)
                 self.committed_state.pop(key, None)
 
-                if not getattr(self.class_, key).impl.accepts_global_callable:
-                    continue
-
-                self.callables[key] = self.__fire_trigger
-                self.expired_attributes.add(key)
+                if getattr(self.class_, key).impl.accepts_scalar_loader:
+                    self.callables[key] = self
+                    self.expired_attributes.add(key)
                 
     def reset(self, key):
         """remove the given attribute and any callables associated with it."""
@@ -1081,7 +1102,7 @@ def _init_class_state(class_):
     if not '_class_state' in class_.__dict__:
         class_._class_state = ClassState()
     
-def register_class(class_, extra_init=None, on_exception=None):
+def register_class(class_, extra_init=None, on_exception=None, deferred_scalar_loader=None):
     # do a sweep first, this also helps some attribute extensions
     # (like associationproxy) become aware of themselves at the 
     # class level
@@ -1089,6 +1110,7 @@ def register_class(class_, extra_init=None, on_exception=None):
         getattr(class_, key, None)
 
     _init_class_state(class_)
+    class_._class_state.deferred_scalar_loader=deferred_scalar_loader
     
     oldinit = None
     doinit = False
index 52e39372d8e0e21c4afb3832888c4a699bbb0b99..486b7b6b66e4706640236ce9de4b75a3c5f106ab 100644 (file)
@@ -15,6 +15,7 @@ ORM.
 """
 from sqlalchemy import util, logging, exceptions
 from sqlalchemy.sql import expression
+from itertools import chain
 class_mapper = None
 
 __all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension',
@@ -505,7 +506,27 @@ def build_path(mapper, key, prev=None):
         return prev + (mapper.base_mapper, key)
     else:
         return (mapper.base_mapper, key)
-        
+
+def serialize_path(path):
+    if path is None:
+        return None
+
+    return [
+        (mapper.class_, mapper.entity_name, key)
+        for mapper, key in [(path[i], path[i+1]) for i in range(0, len(path)-1, 2)]
+    ]
+    
+def deserialize_path(path):
+    if path is None:
+        return None
+
+    global class_mapper
+    if class_mapper is None:
+        from sqlalchemy.orm import class_mapper
+
+    return tuple(
+        chain(*[(class_mapper(cls, entity), key) for cls, entity, key in path])
+    )
 
 class MapperOption(object):
     """Describe a modification to a Query."""
index 8c375ea392c88f806d26604bd6d1721b7fbd8710..db666a4f99e0991aa37def7f9640581a19a187e2 100644 (file)
@@ -758,7 +758,7 @@ class Mapper(object):
         def on_exception(class_, oldinit, instance, args, kwargs):
             util.warn_exception(self.extension.init_failed, self, class_, oldinit, instance, args, kwargs)
 
-        attributes.register_class(self.class_, extra_init=extra_init, on_exception=on_exception)
+        attributes.register_class(self.class_, extra_init=extra_init, on_exception=on_exception, deferred_scalar_loader=_load_scalar_attributes)
         
         self._class_state = self.class_._class_state
         _mapper_registry[self] = True
@@ -1358,42 +1358,22 @@ class Mapper(object):
             instance._sa_session_id = context.session.hash_key
             session_identity_map[identitykey] = instance
         
-        if currentload or context.populate_existing or self.always_refresh or state.trigger:
+        if currentload or context.populate_existing or self.always_refresh:
             if isnew:
                 state.runid = context.runid
-                state.trigger = None
                 context.progress.add(state)
-
+                
             if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
                 self.populate_instance(context, instance, row, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew)
-        
+
+        elif getattr(state, 'expired_attributes', None):
+            if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=state.expired_attributes, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
+                self.populate_instance(context, instance, row, only_load_props=state.expired_attributes, instancekey=identitykey, isnew=isnew)
+            
         if result is not None and ('append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE):
             result.append(instance)
             
         return instance
-                
-    def _deferred_inheritance_condition(self, base_mapper, needs_tables):
-        def visit_binary(binary):
-            leftcol = binary.left
-            rightcol = binary.right
-            if leftcol is None or rightcol is None:
-                return
-            if leftcol.table not in needs_tables:
-                binary.left = sql.bindparam(None, None, type_=binary.right.type)
-                param_names.append((leftcol, binary.left))
-            elif rightcol not in needs_tables:
-                binary.right = sql.bindparam(None, None, type_=binary.right.type)
-                param_names.append((rightcol, binary.right))
-
-        allconds = []
-        param_names = []
-
-        for mapper in self.iterate_to_root():
-            if mapper is base_mapper:
-                break
-            allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary))
-        
-        return sql.and_(*allconds), param_names
 
     def translate_row(self, tomapper, row):
         """Translate the column keys of a row into a new or proxied
@@ -1451,7 +1431,10 @@ class Mapper(object):
             populators = new_populators
         else:
             populators = existing_populators
-                
+
+        if only_load_props:
+            populators = [p for p in populators if p[0] in only_load_props]
+            
         for (key, populator) in populators:
             selectcontext.exec_with_path(self, key, populator, instance, row, ispostselect=ispostselect, isnew=isnew, **flags)
             
@@ -1464,26 +1447,75 @@ class Mapper(object):
             p(state.obj())
 
     def _get_poly_select_loader(self, selectcontext, row):
-        # 'select' or 'union'+col not present
+        """set up attribute loaders for 'select' and 'deferred' polymorphic loading.
+        
+        this loading uses a second SELECT statement to load additional tables,
+        either immediately after loading the main table or via a deferred attribute trigger.
+        """
+        
         (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', self), (None, None))
-        if hosted_mapper is None or not needs_tables or hosted_mapper.polymorphic_fetch == 'deferred':
+        
+        if hosted_mapper is None or not needs_tables:
             return
         
         cond, param_names = self._deferred_inheritance_condition(hosted_mapper, needs_tables)
         statement = sql.select(needs_tables, cond, use_labels=True)
-        def post_execute(instance, **flags):
-            if self.__should_log_debug:
-                self.__log_debug("Post query loading instance " + instance_str(instance))
+        
+        if hosted_mapper.polymorphic_fetch == 'select':
+            def post_execute(instance, **flags):
+                if self.__should_log_debug:
+                    self.__log_debug("Post query loading instance " + instance_str(instance))
+
+                identitykey = self.identity_key_from_instance(instance)
+
+                params = {}
+                for c, bind in param_names:
+                    params[bind] = self._get_attr_by_column(instance, c)
+                row = selectcontext.session.connection(self).execute(statement, params).fetchone()
+                self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True)
+            return post_execute
+        elif hosted_mapper.polymorphic_fetch == 'deferred':
+            from sqlalchemy.orm.strategies import DeferredColumnLoader
+            
+            def post_execute(instance, **flags):
+                def create_statement(instance):
+                    params = {}
+                    for (c, bind) in param_names:
+                        # use the "committed" (database) version to get query column values
+                        params[bind] = self._get_committed_attr_by_column(instance, c)
+                    return (statement, params)
+                
+                props = [prop for prop in [self._get_col_to_prop(col) for col in statement.inner_columns] if prop.key not in instance.__dict__]
+                keys = [p.key for p in props]
+                for prop in props:
+                    strategy = prop._get_strategy(DeferredColumnLoader)
+                    instance._state.set_callable(prop.key, strategy.setup_loader(instance, props=keys, create_statement=create_statement))
+            return post_execute
+        else:
+            return None
+
+    def _deferred_inheritance_condition(self, base_mapper, needs_tables):
+        def visit_binary(binary):
+            leftcol = binary.left
+            rightcol = binary.right
+            if leftcol is None or rightcol is None:
+                return
+            if leftcol.table not in needs_tables:
+                binary.left = sql.bindparam(None, None, type_=binary.right.type)
+                param_names.append((leftcol, binary.left))
+            elif rightcol not in needs_tables:
+                binary.right = sql.bindparam(None, None, type_=binary.right.type)
+                param_names.append((rightcol, binary.right))
 
-            identitykey = self.identity_key_from_instance(instance)
+        allconds = []
+        param_names = []
 
-            params = {}
-            for c, bind in param_names:
-                params[bind] = self._get_attr_by_column(instance, c)
-            row = selectcontext.session.connection(self).execute(statement, params).fetchone()
-            self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True)
+        for mapper in self.iterate_to_root():
+            if mapper is base_mapper:
+                break
+            allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary))
 
-        return post_execute
+        return sql.and_(*allconds), param_names
             
 Mapper.logger = logging.class_logger(Mapper)
 
@@ -1501,6 +1533,16 @@ def has_mapper(object):
 
     return hasattr(object, '_entity_name')
 
+object_session = None
+
+def _load_scalar_attributes(instance, attribute_names):
+    global object_session
+    if not object_session:
+        from sqlalchemy.orm.session import object_session
+        
+    if object_session(instance).query(object_mapper(instance))._get(instance._instance_key, refresh_instance=instance._state, only_load_props=attribute_names) is None:
+        raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % instance_str(instance))
+
 def _state_mapper(state, entity_name=None):
     return state.class_._class_state.mappers[state.dict.get('_entity_name', entity_name)]
 
index 993eba4c5feeb868c2d256f5888120eeac1d2c98..f75d5c36c0e634e7c1b5729c009590bc0ee5be0c 100644 (file)
@@ -1113,7 +1113,7 @@ class Session(object):
         
         return util.IdentitySet(self.uow.new.values())
     new = property(new)
-    
+
 def _expire_state(state, attribute_names):
     """Standalone expire instance function.
 
@@ -1124,12 +1124,6 @@ def _expire_state(state, attribute_names):
     If the list is None or blank, the entire instance is expired.
     """
 
-    if state.trigger is None:
-        def load_attributes(instance, attribute_names):
-            if object_session(instance).query(instance.__class__)._get(instance._instance_key, refresh_instance=instance._state, only_load_props=attribute_names) is None:
-                raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance))
-        state.trigger = load_attributes
-
     state.expire_attributes(attribute_names)
 
 register_attribute = unitofwork.register_attribute
index 60fc0257906cb05f4e02ee54f50938433d2d6df7..33981f16145100efd469cf1a0cfbcb796c2b3756 100644 (file)
@@ -10,7 +10,7 @@ from sqlalchemy import sql, util, exceptions, logging
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.sql import visitors, expression, operators
 from sqlalchemy.orm import mapper, attributes
-from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption
+from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption, serialize_path, deserialize_path
 from sqlalchemy.orm import session as sessionlib
 from sqlalchemy.orm import util as mapperutil
 
@@ -80,53 +80,13 @@ class ColumnLoader(LoaderStrategy):
             if self._should_log_debug:
                 self.logger.debug("Returning active column fetcher for %s %s" % (mapper, self.key))
             return (new_execute, None, None)
-
-        # our mapped column is not present in the row.  check if we need to initialize a polymorphic
-        # row fetcher used by inheritance.
-        (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', mapper), (None, None))
-
-        if hosted_mapper is None:
-            return (None, None, None)
-        
-        if hosted_mapper.polymorphic_fetch == 'deferred':
-            # 'deferred' polymorphic row fetcher, put a callable on the property.
-            # create a deferred column loader which will query the remaining not-yet-loaded tables in an inheritance load.
-            # the mapper for the object creates the WHERE criterion using the mapper who originally 
-            # "hosted" the query and the list of tables which are unloaded between the "hosted" mapper
-            # and this mapper.  (i.e. A->B->C, the query used mapper A.  therefore will need B's and C's tables
-            # in the query).
-            
-            # deferred loader strategy
-            strategy = self.parent_property._get_strategy(DeferredColumnLoader)
-            
-            # full list of ColumnProperty objects to be loaded in the deferred fetch
-            props = [p.key for p in mapper.iterate_properties if isinstance(p.strategy, ColumnLoader) and p.columns[0].table in needs_tables]
-
-            # TODO: we are somewhat duplicating efforts from mapper._get_poly_select_loader 
-            # and should look for ways to simplify.
-            cond, param_names = mapper._deferred_inheritance_condition(hosted_mapper, needs_tables)
-            statement = sql.select(needs_tables, cond, use_labels=True)
-            def create_statement(instance):
-                params = {}
-                for (c, bind) in param_names:
-                    # use the "committed" (database) version to get query column values
-                    params[bind] = mapper._get_committed_attr_by_column(instance, c)
-                return (statement, params)
-            
+        else:
             def new_execute(instance, row, isnew, **flags):
                 if isnew:
-                    instance._state.set_callable(self.key, strategy.setup_loader(instance, props=props, create_statement=create_statement))
-                    
+                    instance._state.expire_attributes([self.key])
             if self._should_log_debug:
-                self.logger.debug("Returning deferred column fetcher for %s %s" % (mapper, self.key))
-                
+                self.logger.debug("Deferring load for %s %s" % (mapper, self.key))
             return (new_execute, None, None)
-        else:  
-            # immediate polymorphic row fetcher.  no processing needed for this row.
-            if self._should_log_debug:
-                self.logger.debug("Returning no column fetcher for %s %s" % (mapper, self.key))
-            return (None, None, None)
-
 
 ColumnLoader.logger = logging.class_logger(ColumnLoader)
 
@@ -170,9 +130,10 @@ class DeferredColumnLoader(LoaderStrategy):
             self.parent_property._get_strategy(ColumnLoader).setup_query(context, **kwargs)
         
     def setup_loader(self, instance, props=None, create_statement=None):
-        localparent = mapper.object_mapper(instance, raiseerror=False)
-        if localparent is None:
+        if not mapper.has_mapper(instance):
             return None
+            
+        localparent = mapper.object_mapper(instance)
 
         # adjust for the ColumnProperty associated with the instance
         # not being our own ColumnProperty.  This can occur when entity_name
@@ -181,39 +142,64 @@ class DeferredColumnLoader(LoaderStrategy):
         prop = localparent.get_property(self.key)
         if prop is not self.parent_property:
             return prop._get_strategy(DeferredColumnLoader).setup_loader(instance)
-            
-        def lazyload():
-            if not mapper.has_identity(instance):
-                return None
-            
-            if props is not None:
-                group = props
-            elif self.group is not None:
-                group = [p.key for p in localparent.iterate_properties if isinstance(p.strategy, DeferredColumnLoader) and p.group==self.group]
-            else:
-                group = [self.parent_property.key]
-            
-            # narrow the keys down to just those which aren't present on the instance
-            group = [k for k in group if k not in instance.__dict__]
-            
-            if self._should_log_debug:
-                self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), group and ','.join(group) or 'None'))
-
-            session = sessionlib.object_session(instance)
-            if session is None:
-                raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
 
-            if create_statement is None:
-                ident = instance._instance_key[1]
-                session.query(localparent)._get(None, ident=ident, only_load_props=group, refresh_instance=instance._state)
-            else:
-                statement, params = create_statement(instance)
-                session.query(localparent).from_statement(statement).params(params)._get(None, only_load_props=group, refresh_instance=instance._state)
-            return attributes.ATTR_WAS_SET
-        return lazyload
+        return LoadDeferredColumns(instance, self.key, props, optimizing_statement=create_statement)
                 
 DeferredColumnLoader.logger = logging.class_logger(DeferredColumnLoader)
 
+class LoadDeferredColumns(object):
+    """callable, serializable loader object used by DeferredColumnLoader"""
+    
+    def __init__(self, instance, key, keys, optimizing_statement):
+        self.instance = instance
+        self.key = key
+        self.keys = keys
+        self.optimizing_statement = optimizing_statement
+
+    def __getstate__(self):
+        return {'instance':self.instance, 'key':self.key, 'keys':self.keys}
+    
+    def __setstate__(self, state):
+        self.instance = state['instance']
+        self.key = state['key']
+        self.keys = state['keys']
+        self.optimizing_statement = None
+        
+    def __call__(self):
+        if not mapper.has_identity(self.instance):
+            return None
+            
+        localparent = mapper.object_mapper(self.instance, raiseerror=False)
+        
+        prop = localparent.get_property(self.key)
+        strategy = prop._get_strategy(DeferredColumnLoader)
+
+        if self.keys:
+            toload = self.keys
+        elif strategy.group:
+            toload = [p.key for p in localparent.iterate_properties if isinstance(p.strategy, DeferredColumnLoader) and p.group==strategy.group]
+        else:
+            toload = [self.key]
+
+        # narrow the keys down to just those which have no history
+        group = [k for k in toload if k in self.instance._state.unmodified]
+
+        if strategy._should_log_debug:
+            strategy.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(self.instance, self.key), group and ','.join(group) or 'None'))
+
+        session = sessionlib.object_session(self.instance)
+        if session is None:
+            raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (self.instance.__class__, self.key))
+
+        query = session.query(localparent)
+        if not self.optimizing_statement:
+            ident = self.instance._instance_key[1]
+            query._get(None, ident=ident, only_load_props=group, refresh_instance=self.instance._state)
+        else:
+            statement, params = self.optimizing_statement(self.instance)
+            query.from_statement(statement).params(params)._get(None, only_load_props=group, refresh_instance=self.instance._state)
+        return attributes.ATTR_WAS_SET
+
 class DeferredOption(StrategizedOption):
     def __init__(self, key, defer=False):
         super(DeferredOption, self).__init__(key)
@@ -276,7 +262,7 @@ NoLoader.logger = logging.class_logger(NoLoader)
 class LazyLoader(AbstractRelationLoader):
     def init(self):
         super(LazyLoader, self).init()
-        (self.lazywhere, self.lazybinds, self.lazyreverse) = self._create_lazy_clause(self)
+        (self.lazywhere, self.lazybinds, self.equated_columns) = self._create_lazy_clause(self)
         
         self.logger.info(str(self.parent_property) + " lazy loading clause " + str(self.lazywhere))
 
@@ -293,10 +279,10 @@ class LazyLoader(AbstractRelationLoader):
 
     def lazy_clause(self, instance, reverse_direction=False):
         if instance is None:
-            return self.lazy_none_clause(reverse_direction)
+            return self._lazy_none_clause(reverse_direction)
             
         if not reverse_direction:
-            (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.lazyreverse)
+            (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.equated_columns)
         else:
             (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
         bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
@@ -308,9 +294,9 @@ class LazyLoader(AbstractRelationLoader):
                 bindparam.value = mapper._get_committed_attr_by_column(instance, bind_to_col[bindparam.key])
         return visitors.traverse(criterion, clone=True, visit_bindparam=visit_bindparam)
     
-    def lazy_none_clause(self, reverse_direction=False):
+    def _lazy_none_clause(self, reverse_direction=False):
         if not reverse_direction:
-            (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.lazyreverse)
+            (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.equated_columns)
         else:
             (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
         bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
@@ -331,71 +317,18 @@ class LazyLoader(AbstractRelationLoader):
     def setup_loader(self, instance, options=None, path=None):
         if not mapper.has_mapper(instance):
             return None
-        else:
-            # adjust for the PropertyLoader associated with the instance
-            # not being our own PropertyLoader.  This can occur when entity_name
-            # mappers are used to map different versions of the same PropertyLoader
-            # to the class.
-            prop = mapper.object_mapper(instance).get_property(self.key)
-            if prop is not self.parent_property:
-                return prop._get_strategy(LazyLoader).setup_loader(instance)
-
-        def lazyload():
-            if self._should_log_debug:
-                self.logger.debug("lazy load attribute %s on instance %s" % (self.key, mapperutil.instance_str(instance)))
 
-            if not mapper.has_identity(instance):
-                return None
+        localparent = mapper.object_mapper(instance)
 
-            session = sessionlib.object_session(instance)
-            if session is None:
-                try:
-                    session = mapper.object_mapper(instance).get_session()
-                except exceptions.InvalidRequestError:
-                    raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
-
-            # if we have a simple straight-primary key load, use mapper.get()
-            # to possibly save a DB round trip
-            q = session.query(self.mapper).autoflush(False)
-            if path:
-                q = q._with_current_path(path)
-            if self.use_get:
-                params = {}
-                for col, bind in self.lazybinds.iteritems():
-                    # use the "committed" (database) version to get query column values
-                    params[bind.key] = self.parent._get_committed_attr_by_column(instance, col)
-                ident = []
-                nonnulls = False
-                for primary_key in self.select_mapper.primary_key: 
-                    bind = self.lazyreverse[primary_key]
-                    v = params[bind.key]
-                    if v is not None:
-                        nonnulls = True
-                    ident.append(v)
-                if not nonnulls:
-                    return None
-                if options:
-                    q = q._conditional_options(*options)
-                return q.get(ident)
-            elif self.order_by is not False:
-                q = q.order_by(self.order_by)
-            elif self.secondary is not None and self.secondary.default_order_by() is not None:
-                q = q.order_by(self.secondary.default_order_by())
-
-            if options:
-                q = q._conditional_options(*options)
-            q = q.filter(self.lazy_clause(instance))
-
-            result = q.all()
-            if self.uselist:
-                return result
-            else:
-                if result:
-                    return result[0]
-                else:
-                    return None
-
-        return lazyload
+        # adjust for the PropertyLoader associated with the instance
+        # not being our own PropertyLoader.  This can occur when entity_name
+        # mappers are used to map different versions of the same PropertyLoader
+        # to the class.
+        prop = localparent.get_property(self.key)
+        if prop is not self.parent_property:
+            return prop._get_strategy(LazyLoader).setup_loader(instance)
+        
+        return LoadLazyAttribute(instance, self.key, options, path)
 
     def create_row_processor(self, selectcontext, mapper, row):
         if not self.is_class_level or len(selectcontext.options):
@@ -424,7 +357,7 @@ class LazyLoader(AbstractRelationLoader):
         (primaryjoin, secondaryjoin, remote_side) = (prop.polymorphic_primaryjoin, prop.polymorphic_secondaryjoin, prop.remote_side)
         
         binds = {}
-        reverse = {}
+        equated_columns = {}
 
         def should_bind(targetcol, othercol):
             if reverse_direction and not secondaryjoin:
@@ -437,20 +370,17 @@ class LazyLoader(AbstractRelationLoader):
                 return
             leftcol = binary.left
             rightcol = binary.right
-            
+
+            equated_columns[rightcol] = leftcol
+            equated_columns[leftcol] = rightcol
+
             if should_bind(leftcol, rightcol):
-                col = leftcol
-                binary.left = binds.setdefault(leftcol,
-                        sql.bindparam(None, None, type_=binary.right.type))
-                reverse[rightcol] = binds[col]
+                binary.left = binds[leftcol] = sql.bindparam(None, None, type_=binary.right.type)
 
             # the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1",
             # which can happen in rare cases (test/orm/relationships.py RelationTest2)
             if leftcol is not rightcol and should_bind(rightcol, leftcol):
-                col = rightcol
-                binary.right = binds.setdefault(rightcol,
-                        sql.bindparam(None, None, type_=binary.left.type))
-                reverse[leftcol] = binds[col]
+                binary.right = binds[rightcol] = sql.bindparam(None, None, type_=binary.left.type)
 
         lazywhere = primaryjoin
         
@@ -461,11 +391,86 @@ class LazyLoader(AbstractRelationLoader):
             if reverse_direction:
                 secondaryjoin = visitors.traverse(secondaryjoin, clone=True, visit_binary=visit_binary)
             lazywhere = sql.and_(lazywhere, secondaryjoin)
-        return (lazywhere, binds, reverse)
+        return (lazywhere, binds, equated_columns)
     _create_lazy_clause = classmethod(_create_lazy_clause)
     
 LazyLoader.logger = logging.class_logger(LazyLoader)
 
+class LoadLazyAttribute(object):
+    """callable, serializable loader object used by LazyLoader"""
+
+    def __init__(self, instance, key, options, path):
+        self.instance = instance
+        self.key = key
+        self.options = options
+        self.path = path
+        
+    def __getstate__(self):
+        return {'instance':self.instance, 'key':self.key, 'options':self.options, 'path':serialize_path(self.path)}
+
+    def __setstate__(self, state):
+        self.instance = state['instance']
+        self.key = state['key']
+        self.options = state['options']
+        self.path = deserialize_path(state['path'])
+        
+    def __call__(self):
+        instance = self.instance
+        
+        if not mapper.has_identity(instance):
+            return None
+
+        instance_mapper = mapper.object_mapper(instance)
+        prop = instance_mapper.get_property(self.key)
+        strategy = prop._get_strategy(LazyLoader)
+        
+        if strategy._should_log_debug:
+            strategy.logger.debug("lazy load attribute %s on instance %s" % (self.key, mapperutil.instance_str(instance)))
+
+        session = sessionlib.object_session(instance)
+        if session is None:
+            try:
+                session = instance_mapper.get_session()
+            except exceptions.InvalidRequestError:
+                raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
+
+        q = session.query(prop.mapper).autoflush(False)
+        if self.path:
+            q = q._with_current_path(self.path)
+            
+        # if we have a simple primary key load, use mapper.get()
+        # to possibly save a DB round trip
+        if strategy.use_get:
+            ident = []
+            allnulls = True
+            for primary_key in prop.select_mapper.primary_key: 
+                val = instance_mapper._get_committed_attr_by_column(instance, strategy.equated_columns[primary_key])
+                allnulls = allnulls and val is None
+                ident.append(val)
+            if allnulls:
+                return None
+            if self.options:
+                q = q._conditional_options(*self.options)
+            return q.get(ident)
+            
+        if strategy.order_by is not False:
+            q = q.order_by(strategy.order_by)
+        elif strategy.secondary is not None and strategy.secondary.default_order_by() is not None:
+            q = q.order_by(strategy.secondary.default_order_by())
+
+        if self.options:
+            q = q._conditional_options(*self.options)
+        q = q.filter(strategy.lazy_clause(instance))
+
+        result = q.all()
+        if strategy.uselist:
+            return result
+        else:
+            if result:
+                return result[0]
+            else:
+                return None
+        
 
 class EagerLoader(AbstractRelationLoader):
     """Loads related objects inline with a parent query."""
@@ -630,8 +635,7 @@ class EagerLoader(AbstractRelationLoader):
             if self._should_log_debug:
                 self.logger.debug("eager loader %s degrading to lazy loader" % str(self))
             return self.parent_property._get_strategy(LazyLoader).create_row_processor(selectcontext, mapper, row)
-        
-            
+
     def __str__(self):
         return str(self.parent) + "." + self.key
         
index 7b76183be01647f9a4a7f9338dbd466c434c3dd4..397d99c0fb94a5baaf7c234d02a1cb1846f2b190 100644 (file)
@@ -284,8 +284,10 @@ def instance_str(instance):
 
 def state_str(state):
     """Return a string describing an instance."""
-
-    return state.class_.__name__ + "@" + hex(id(state.obj()))
+    if state is None:
+        return "None"
+    else:
+        return state.class_.__name__ + "@" + hex(id(state.obj()))
 
 def attribute_str(instance, attribute):
     return instance_str(instance) + "." + attribute
index ff370bc595b4253cea9cd8fc730ca8456be9d0d6..df3dbd279ab425c8cfd4ed865562dc4d1ddc356c 100644 (file)
@@ -47,7 +47,6 @@ __all__ = [
     'subquery', 'table', 'text', 'union', 'union_all', 'update', ]
 
 
-BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
 
 def desc(column):
     """Return a descending ``ORDER BY`` clause element.
@@ -1795,6 +1794,8 @@ class _TextClause(ClauseElement):
 
     __visit_name__ = 'textclause'
 
+    _bind_params_regex = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
+
     def __init__(self, text = "", bind=None, bindparams=None, typemap=None):
         self._bind = bind
         self.bindparams = {}
@@ -1809,7 +1810,7 @@ class _TextClause(ClauseElement):
 
         # scan the string and search for bind parameter names, add them
         # to the list of bindparams
-        self.text = BIND_PARAMS.sub(repl, text)
+        self.text = self._bind_params_regex.sub(repl, text)
         if bindparams is not None:
             for b in bindparams:
                 self.bindparams[b.key] = b
index 3748d3f34e240607d06f8935524369e76766b117..dd2bd8446322004cea79e94c79feee49fd6e9cca 100644 (file)
@@ -26,6 +26,7 @@ def suite():
         'orm.relationships',
         'orm.association',
         'orm.merge',
+        'orm.pickled',
         'orm.memusage',
         
         'orm.cycles',
index a756566d5f087f49249455b4830685a50aceb047..dd15e41e5694d55493f753f9ff3c5fdf047d1db3 100644 (file)
@@ -95,6 +95,65 @@ class AttributesTest(PersistTest):
         self.assert_(o4.mt2[0].a == 'abcde')
         self.assert_(o4.mt2[0].b is None)
 
+    def test_deferred(self):
+        class Foo(object):pass
+        
+        data = {'a':'this is a', 'b':12}
+        def loader(instance, keys):
+            for k in keys:
+                instance.__dict__[k] = data[k]
+            return attributes.ATTR_WAS_SET
+            
+        attributes.register_class(Foo, deferred_scalar_loader=loader)
+        attributes.register_attribute(Foo, 'a', uselist=False, useobject=False)
+        attributes.register_attribute(Foo, 'b', uselist=False, useobject=False)
+        
+        f = Foo()
+        f._state.expire_attributes(None)
+        self.assertEquals(f.a, "this is a")
+        self.assertEquals(f.b, 12)
+        
+        f.a = "this is some new a"
+        f._state.expire_attributes(None)
+        self.assertEquals(f.a, "this is a")
+        self.assertEquals(f.b, 12)
+
+        f._state.expire_attributes(None)
+        f.a = "this is another new a"
+        self.assertEquals(f.a, "this is another new a")
+        self.assertEquals(f.b, 12)
+
+        f._state.expire_attributes(None)
+        self.assertEquals(f.a, "this is a")
+        self.assertEquals(f.b, 12)
+
+        del f.a
+        self.assertEquals(f.a, None)
+        self.assertEquals(f.b, 12)
+        
+        f._state.commit_all()
+        self.assertEquals(f.a, None)
+        self.assertEquals(f.b, 12)
+
+    def test_deferred_pickleable(self):
+        data = {'a':'this is a', 'b':12}
+        def loader(instance, keys):
+            for k in keys:
+                instance.__dict__[k] = data[k]
+            return attributes.ATTR_WAS_SET
+            
+        attributes.register_class(MyTest, deferred_scalar_loader=loader)
+        attributes.register_attribute(MyTest, 'a', uselist=False, useobject=False)
+        attributes.register_attribute(MyTest, 'b', uselist=False, useobject=False)
+        
+        m = MyTest()
+        m._state.expire_attributes(None)
+        assert 'a' not in m.__dict__
+        m2 = pickle.loads(pickle.dumps(m))
+        assert 'a' not in m2.__dict__
+        self.assertEquals(m2.a, "this is a")
+        self.assertEquals(m2.b, 12)
+        
     def test_list(self):
         class User(object):pass
         class Address(object):pass
@@ -860,7 +919,6 @@ class HistoryTest(PersistTest):
         self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [], []))
 
         lazy_load = [bar1, bar2, bar3]
-        f._state.trigger = lazyload(f)
         f._state.expire_attributes(['bars'])
         self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar2, bar3], []))
         
index be9c881c0c87c95cb78afbe567a44f7a6d05e4cd..e1c0166120e7b845be0ee444bed76c61ec6de58a 100644 (file)
@@ -31,8 +31,6 @@ class ExpireTest(FixtureTest):
         self.assert_sql_count(testbase.db, go, 1)
         assert 'name' in u.__dict__
 
-        # we're changing the database here, so if this test fails in the middle,
-        # it'll screw up the other tests which are hardcoded to 7/'jack'
         u.name = 'foo'
         sess.flush()
         # change the value in the DB
@@ -45,9 +43,9 @@ class ExpireTest(FixtureTest):
         # test that it refreshed
         assert u.__dict__['name'] == 'jack'
 
-        # object should be back to normal now,
-        # this should *not* produce a SELECT statement (not tested here though....)
-        assert u.name == 'jack'
+        def go():
+            assert u.name == 'jack'
+        self.assert_sql_count(testbase.db, go, 0)
     
     def test_expire_doesntload_on_set(self):
         mapper(User, users)
@@ -76,6 +74,15 @@ class ExpireTest(FixtureTest):
             assert o.isopen == 1
         self.assert_sql_count(testbase.db, go, 1)
         assert o.description == 'order 3 modified'
+
+        del o.description
+        assert "description" not in o.__dict__
+        sess.expire(o, ['isopen'])
+        sess.query(Order).all()
+        assert o.isopen == 1
+        assert "description" not in o.__dict__
+
+        assert o.description is None
         
     def test_expire_committed(self):
         """test that the committed state of the attribute receives the most recent DB data"""
@@ -144,11 +151,16 @@ class ExpireTest(FixtureTest):
         def go():
             assert u.addresses[0].email_address == 'jack@bean.com'
             assert u.name == 'jack'
-        # one load
-        self.assert_sql_count(testbase.db, go, 1)
+        # two loads, since relation() + scalar are 
+        # separate right now
+        self.assert_sql_count(testbase.db, go, 2)
         assert 'name' in u.__dict__
         assert 'addresses' in u.__dict__
 
+        sess.expire(u, ['name', 'addresses'])
+        assert 'name' not in u.__dict__
+        assert 'addresses' not in u.__dict__
+
     def test_partial_expire(self):
         mapper(Order, orders)
 
@@ -380,6 +392,9 @@ class RefreshTest(FixtureTest):
         s.expire(u)
 
         # get the attribute, it refreshes
+        print "OK------"
+#        print u.__dict__
+#        print u._state.callables
         assert u.name == 'jack'
         assert id(a) not in [id(x) for x in u.addresses]
 
index 662ac4a29f249c5aa7e68a3ab17471a810b9d772..a7cb3a57d31d07a594c54b87ce74194e5917ee3b 100644 (file)
@@ -87,7 +87,9 @@ class MapperTest(MapperSuperTest):
         a = s.query(Address).from_statement(select([addresses.c.address_id, addresses.c.user_id])).first()
         assert a.user_id == 7
         assert a.address_id == 1
-        assert a.email_address is None
+        # email address auto-defers
+        assert 'email_address' not in a.__dict__
+        assert a.email_address == 'jack@bean.com'
 
     def test_badconstructor(self):
         """test that if the construction of a mapped class fails, the instnace does not get placed in the session"""
diff --git a/test/orm/pickled.py b/test/orm/pickled.py
new file mode 100644 (file)
index 0000000..eac6280
--- /dev/null
@@ -0,0 +1,119 @@
+import testbase
+from sqlalchemy import *
+from sqlalchemy import exceptions
+from sqlalchemy.orm import *
+from testlib import *
+from testlib.fixtures import *
+import pickle
+
+class EmailUser(User):
+    pass
+    
+class PickleTest(FixtureTest):
+    keep_mappers = False
+    keep_data = False
+    
+    def test_transient(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref="user")
+        })
+        mapper(Address, addresses)
+        
+        sess = create_session()
+        u1 = User(name='ed')
+        u1.addresses.append(Address(email_address='ed@bar.com'))
+        
+        u2 = pickle.loads(pickle.dumps(u1))
+        sess.save(u2)
+        sess.flush()
+        
+        sess.clear()
+        
+        self.assertEquals(u1, sess.query(User).get(u2.id))
+    
+    def test_class_deferred_cols(self):
+        mapper(User, users, properties={
+            'name':deferred(users.c.name),
+            'addresses':relation(Address, backref="user")
+        })
+        mapper(Address, addresses, properties={
+            'email_address':deferred(addresses.c.email_address)
+        })
+        sess = create_session()
+        u1 = User(name='ed')
+        u1.addresses.append(Address(email_address='ed@bar.com'))
+        sess.save(u1)
+        sess.flush()
+        sess.clear()
+        u1 = sess.query(User).get(u1.id)
+        assert 'name' not in u1.__dict__
+        assert 'addresses' not in u1.__dict__
+        
+        u2 = pickle.loads(pickle.dumps(u1))
+        sess2 = create_session()
+        sess2.update(u2)
+        self.assertEquals(u2.name, 'ed')
+        self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
+        
+    def test_instance_deferred_cols(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref="user")
+        })
+        mapper(Address, addresses)
+        
+        sess = create_session()
+        u1 = User(name='ed')
+        u1.addresses.append(Address(email_address='ed@bar.com'))
+        sess.save(u1)
+        sess.flush()
+        sess.clear()
+        
+        u1 = sess.query(User).options(defer('name'), defer('addresses.email_address')).get(u1.id)
+        assert 'name' not in u1.__dict__
+        assert 'addresses' not in u1.__dict__
+        
+        u2 = pickle.loads(pickle.dumps(u1))
+        sess2 = create_session()
+        sess2.update(u2)
+        self.assertEquals(u2.name, 'ed')
+        assert 'addresses' not in u1.__dict__
+        ad = u2.addresses[0]
+        assert 'email_address' not in ad.__dict__
+        self.assertEquals(ad.email_address, 'ed@bar.com')
+        self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
+
+class PolymorphicDeferredTest(ORMTest):
+    def define_tables(self, metadata):
+        global users, email_users
+        users = Table('users', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('name', String(30)),
+            Column('type', String(30)),
+            )
+        email_users = Table('email_users', metadata,
+            Column('id', Integer, ForeignKey('users.id'), primary_key=True),
+            Column('email_address', String(30))
+            )
+            
+    def test_polymorphic_deferred(self):
+        mapper(User, users, polymorphic_identity='user', polymorphic_on=users.c.type, polymorphic_fetch='deferred')
+        mapper(EmailUser, email_users, inherits=User, polymorphic_identity='emailuser')
+        
+        eu = EmailUser(name="user1", email_address='foo@bar.com')
+        sess = create_session()
+        sess.save(eu)
+        sess.flush()
+        sess.clear()
+        
+        eu = sess.query(User).first()
+        eu2 = pickle.loads(pickle.dumps(eu))
+        sess2 = create_session()
+        sess2.update(eu2)
+        assert 'email_address' not in eu2.__dict__
+        self.assertEquals(eu2.email_address, 'foo@bar.com')
+        
+        
+        
+
+if __name__ == '__main__':
+    testbase.main()
\ No newline at end of file