]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- session.refresh() and session.expire() now support an additional argument
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 18 Nov 2007 02:13:56 +0000 (02:13 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 18 Nov 2007 02:13:56 +0000 (02:13 +0000)
"attribute_names", a list of individual attribute keynames to be refreshed
or expired, allowing partial reloads of attributes on an already-loaded
instance.
- finally simplified the behavior of deferred attributes, deferred polymorphic
load, session.refresh, session.expire, mapper._postfetch to all use a single
codepath through query._get(), which now supports a list of individual attribute names
to be refreshed.  the *one* exception still remaining is mapper._get_poly_select_loader(),
which may stay that way since its inline with an already processing load operation.
otherwise, query._get() is the single place that all "load this instance's row" operation
proceeds.
- cleanup all over the place

20 files changed:
CHANGES
doc/build/content/session.txt
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/unitofwork.py
test/orm/alltests.py
test/orm/attributes.py
test/orm/dynamic.py
test/orm/eager_relations.py
test/orm/expire.py [new file with mode: 0644]
test/orm/inheritance/basic.py
test/orm/lazy_relations.py
test/orm/mapper.py
test/orm/query.py
test/orm/unitofwork.py
test/testlib/fixtures.py

diff --git a/CHANGES b/CHANGES
index d2a82953eb6d06cb8ace90efc33ad9aadb78a544..7fc24cfd660547707c12efca312a5d9fe588b267 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -76,6 +76,11 @@ CHANGES
     columns to the result set.  This eliminates a JOIN from all eager loads
     with LIMIT/OFFSET.  [ticket:843]
 
+  - session.refresh() and session.expire() now support an additional argument
+    "attribute_names", a list of individual attribute keynames to be refreshed
+    or expired, allowing partial reloads of attributes on an already-loaded 
+    instance.
+    
   - Mapped classes may now define __eq__, __hash__, and __nonzero__ methods
     with arbitrary sementics.  The orm now handles all mapped instances on
     an identity-only basis. (e.g. 'is' vs '==') [ticket:676]
index 804769a943403366dea563326f83d8680d00ea3a..7698a1b5cf3abafa689bb4cd37aaf793c642f062 100644 (file)
@@ -396,6 +396,18 @@ To assist with the Session's "sticky" behavior of instances which are present, i
     session.expire(obj1)
     session.expire(obj2)
 
+`refresh()` and `expire()` also support being passed a list of individual attribute names in which to be refreshed.  These names can reference any attribute, column-based or relation based:
+
+    {python}
+    # immediately re-load the attributes 'hello', 'world' on obj1, obj2
+    session.refresh(obj1, ['hello', 'world'])
+    session.refresh(obj2, ['hello', 'world'])
+    
+    # expire the attriibutes 'hello', 'world' objects obj1, obj2, attributes will be reloaded
+    # on the next access:
+    session.expire(obj1, ['hello', 'world'])
+    session.expire(obj2, ['hello', 'world'])
+
 ## Cascades
 
 Mappers support the concept of configurable *cascade* behavior on `relation()`s.  This behavior controls how the Session should treat the instances that have a parent-child relationship with another instance that is operated upon by the Session.  Cascade is indicated as a comma-separated list of string keywords, with the possible values `all`, `delete`, `save-update`, `refresh-expire`, `merge`, `expunge`, and `delete-orphan`.
index 5e22bc7628a3cfeaccbd4ab255c4c1da69350eb2..7c760f15a89f4685e05bd80e4b0e6d9bd6e18533 100644 (file)
@@ -177,41 +177,16 @@ class AttributeImpl(object):
         if callable_ is None:
             self.initialize(state)
         else:
-            state.callables[self] = callable_
+            state.callables[self.key] = callable_
 
     def _get_callable(self, state):
-        if self in state.callables:
-            return state.callables[self]
+        if self.key in state.callables:
+            return state.callables[self.key]
         elif self.callable_ is not None:
             return self.callable_(state.obj())
         else:
             return None
 
-    def reset(self, state):
-        """Remove any per-instance callable functions corresponding to
-        this ``InstrumentedAttribute``'s attribute from the given
-        object, and remove this ``InstrumentedAttribute``'s attribute
-        from the given object's dictionary.
-        """
-
-        try:
-            del state.callables[self]
-        except KeyError:
-            pass
-        self.clear(state)
-
-    def clear(self, state):
-        """Remove this ``InstrumentedAttribute``'s attribute from the given object's dictionary.
-
-        Subsequent calls to ``getattr(obj, key)`` will raise an
-        ``AttributeError`` by default.
-        """
-
-        try:
-            del state.dict[self.key]
-        except KeyError:
-            pass
-
     def check_mutable_modified(self, state):
         return False
 
@@ -232,11 +207,6 @@ class AttributeImpl(object):
         try:
             return state.dict[self.key]
         except KeyError:
-            # if an instance-wide "trigger" was set, call that
-            # and start again
-            if state.trigger:
-                state.call_trigger()
-                return self.get(state, passive=passive)
 
             callable_ = self._get_callable(state)
             if callable_ is not None:
@@ -246,6 +216,8 @@ class AttributeImpl(object):
                 if value is not ATTR_WAS_SET:
                     return self.set_committed_value(state, value)
                 else:
+                    if self.key not in state.dict:
+                        return self.get(state, passive=passive)
                     return state.dict[self.key]
             else:
                 # Return a new, empty value
@@ -278,10 +250,6 @@ class AttributeImpl(object):
         state.dict[self.key] = value
         return value
 
-    def set_raw_value(self, state, value):
-        state.dict[self.key] = value
-        return value
-
     def fire_append_event(self, state, value, initiator):
         state.modified = True
         if self.trackparent and value is not None:
@@ -318,7 +286,8 @@ class ScalarAttributeImpl(AttributeImpl):
         if copy_function is None:
             copy_function = self.__copy
         self.copy = copy_function
-
+        self.accepts_global_callable = True
+        
     def __copy(self, item):
         # scalar values are assumed to be immutable unless a copy function
         # is passed
@@ -350,10 +319,6 @@ class ScalarAttributeImpl(AttributeImpl):
         if initiator is self:
             return
 
-        # if an instance-wide "trigger" was set, call that
-        if state.trigger:
-            state.call_trigger()
-
         state.dict[self.key] = value
         state.modified=True
 
@@ -372,6 +337,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
           compare_function=compare_function, mutable_scalars=mutable_scalars, **kwargs)
         if compare_function is None:
             self.is_equal = identity_equal
+        self.accepts_global_callable = False
 
     def delete(self, state):
         old = self.get(state)
@@ -389,10 +355,6 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
         if initiator is self:
             return
 
-        # if an instance-wide "trigger" was set, call that
-        if state.trigger:
-            state.call_trigger()
-
         old = self.get(state)
         state.dict[self.key] = value
         self.fire_replace_event(state, value, old, initiator)
@@ -416,6 +378,8 @@ class CollectionAttributeImpl(AttributeImpl):
             copy_function = self.__copy
         self.copy = copy_function
 
+        self.accepts_global_callable = False
+
         if typecallable is None:
             typecallable = list
         self.collection_factory = \
@@ -478,10 +442,6 @@ class CollectionAttributeImpl(AttributeImpl):
         elif setting_type == dict:
             value = value.values()
 
-        # if an instance-wide "trigger" was set, call that
-        if state.trigger:
-            state.call_trigger()
-
         old = self.get(state)
         old_collection = self.get_collection(state, old)
 
@@ -583,13 +543,13 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
 class InstanceState(object):
     """tracks state information at the instance level."""
 
-    __slots__ = 'class_', 'obj', 'dict', 'committed_state', 'modified', 'trigger', 'callables', 'parents', 'instance_dict', '_strong_obj'
+    __slots__ = 'class_', 'obj', 'dict', 'committed_state', 'modified', 'trigger', 'callables', 'parents', 'instance_dict', '_strong_obj', 'expired_attributes'
     
     def __init__(self, obj):
         self.class_ = obj.__class__
         self.obj = weakref.ref(obj, self.__cleanup)
         self.dict = obj.__dict__
-        self.committed_state = None
+        self.committed_state = {}
         self.modified = False
         self.trigger = None
         self.callables = {}
@@ -649,36 +609,75 @@ class InstanceState(object):
         self.dict = self.obj().__dict__
         self.callables = {}
         self.trigger = None
+    
+    def initialize(self, key):
+        getattr(self.class_, key).impl.initialize(self)
         
-    def call_trigger(self):
-        trig = self.trigger
-        self.trigger = None
-        trig()
+    def set_callable(self, key, callable_):
+        self.dict.pop(key, None)
+        self.callables[key] = callable_
+
+    def __fire_trigger(self):
+        self.trigger(self.obj(), self.expired_attributes)
+        for k in self.expired_attributes:
+            self.callables.pop(k, None)
+        self.expired_attributes.clear()
+        return ATTR_WAS_SET
+    
+    def expire_attributes(self, attribute_names):
+        if not hasattr(self, 'expired_attributes'):
+            self.expired_attributes = util.Set()
+        if attribute_names is None:
+            for attr in self.class_._sa_attribute_manager.managed_attributes(self.class_):
+                self.dict.pop(attr.impl.key, None)
+                self.callables[attr.impl.key] = self.__fire_trigger
+                self.expired_attributes.add(attr.impl.key)
+        else:
+            for key in attribute_names:
+                self.dict.pop(key, None)
+
+                if not getattr(self.class_, key).impl.accepts_global_callable:
+                    continue
+
+                self.callables[key] = self.__fire_trigger
+                self.expired_attributes.add(key)
+                
+    def reset(self, key):
+        """remove the given attribute and any callables associated with it."""
+        
+        self.dict.pop(key, None)
+        self.callables.pop(key, None)
+        
+    def clear(self):
+        """clear all attributes from the instance."""
+        
+        for attr in self.class_._sa_attribute_manager.managed_attributes(self.class_):
+            self.dict.pop(attr.impl.key, None)
+    
+    def commit(self, keys):
+        """commit all attributes named in the given list of key names.
+        
+        This is used by a partial-attribute load operation to mark committed those attributes
+        which were refreshed from the database.
+        """
+        
+        for key in keys:
+            getattr(self.class_, key).impl.commit_to_state(self)
+            
+    def commit_all(self):
+        """commit all attributes unconditionally.
+        
+        This is used after a flush() or a regular instance load or refresh operation
+        to mark committed all populated attributes.
+        """
         
-    def commit(self, manager, obj):
         self.committed_state = {}
         self.modified = False
-        for attr in manager.managed_attributes(obj.__class__):
+        for attr in self.class_._sa_attribute_manager.managed_attributes(self.class_):
             attr.impl.commit_to_state(self)
         # remove strong ref
         self._strong_obj = None
         
-    def rollback(self, manager, obj):
-        if not self.committed_state:
-            manager._clear(obj)
-        else:
-            for attr in manager.managed_attributes(obj.__class__):
-                if attr.impl.key in self.committed_state:
-                    if not hasattr(attr.impl, 'get_collection'):
-                        obj.__dict__[attr.impl.key] = self.committed_state[attr.impl.key]
-                    else:
-                        collection = attr.impl.get_collection(self)
-                        collection.clear_without_event()
-                        for item in self.committed_state[attr.impl.key]:
-                            collection.append_without_event(item)
-                else:
-                    if attr.impl.key in self.dict:
-                        del self.dict[attr.impl.key]
 
 class InstanceDict(UserDict.UserDict):
     """similar to WeakValueDictionary, but wired towards 'state' objects."""
@@ -878,28 +877,6 @@ class AttributeManager(object):
     def clear_attribute_cache(self):
         self._attribute_cache.clear()
 
-    def rollback(self, *obj):
-        """Retrieve the committed history for each object in the given
-        list, and rolls back the attributes each instance to their
-        original value.
-        """
-
-        for o in obj:
-            o._state.rollback(self, o)
-
-    def _clear(self, obj):
-        for attr in self.managed_attributes(obj.__class__):
-            try:
-                del obj.__dict__[attr.impl.key]
-            except KeyError:
-                pass
-    
-    def commit(self, *obj):
-        """Establish the "committed state" for each object in the given list."""
-
-        for o in obj:
-            o._state.commit(self, o)
-
     def managed_attributes(self, class_):
         """Return a list of all ``InstrumentedAttribute`` objects
         associated with the given class.
@@ -970,60 +947,9 @@ class AttributeManager(object):
         else:
             return [x]
 
-    def trigger_history(self, obj, callable):
-        """Clear all managed object attributes and places the given
-        `callable` as an attribute-wide *trigger*, which will execute
-        upon the next attribute access, after which the trigger is
-        removed.
-        """
-
-        s = obj._state
-        self._clear(obj)
-        s.committed_state = None
-        s.trigger = callable
-
-    def untrigger_history(self, obj):
-        """Remove a trigger function set by trigger_history.
-
-        Does not restore the previous state of the object.
-        """
-
-        obj._state.trigger = None
-
-    def has_trigger(self, obj):
-        """Return True if the given object has a trigger function set
-        by ``trigger_history()``.
-        """
-
-        return obj._state.trigger is not None
-
-    def reset_instance_attribute(self, obj, key):
-        """Remove any per-instance callable functions corresponding to
-        given attribute `key` from the given object, and remove this
-        attribute from the given object's dictionary.
-        """
-
-        attr = getattr(obj.__class__, key)
-        attr.impl.reset(obj._state)
-
-    def is_class_managed(self, class_, key):
-        """Return True if the given `key` correponds to an
-        instrumented property on the given class.
-        """
-        return hasattr(class_, key) and isinstance(getattr(class_, key), InstrumentedAttribute)
-
     def has_parent(self, class_, obj, key, optimistic=False):
         return getattr(class_, key).impl.hasparent(obj._state, optimistic=optimistic)
 
-    def init_instance_attribute(self, obj, key, callable_=None, clear=False):
-        """Initialize an attribute on an instance to either a blank
-        value, cancelling out any class- or instance-level callables
-        that were present, or if a `callable` is supplied set the
-        callable to be invoked when the attribute is next accessed.
-        """
-
-        getattr(obj.__class__, key).impl.set_callable(obj._state, callable_, clear=clear)
-
     def _create_prop(self, class_, key, uselist, callable_, typecallable, useobject, **kwargs):
         """Create a scalar property object, defaulting to
         ``InstrumentedAttribute``, which will communicate change
@@ -1135,12 +1061,6 @@ class AttributeManager(object):
         setattr(class_, key, InstrumentedAttribute(self._create_prop(class_, key, uselist, callable_, useobject=useobject,
                                            typecallable=typecallable, **kwargs), comparator=comparator))
 
-    def set_raw_value(self, instance, key, value):
-        getattr(instance.__class__, key).impl.set_raw_value(instance._state, value)
-
-    def set_committed_value(self, instance, key, value):
-        getattr(instance.__class__, key).impl.set_committed_value(instance._state, value)
-
     def init_collection(self, instance, key):
         """Initialize a collection attribute and return the collection adapter."""
         attr = getattr(instance.__class__, key).impl
index f523d0706bb9ea0462d678412e318b030ac4fd04..a4a744c4731d4608f4c110d134f88a7c8542729c 100644 (file)
@@ -1178,7 +1178,7 @@ class Mapper(object):
                 prop = self._getpropbycolumn(c, raiseerror=False)
                 if prop is None:
                     continue
-                deferred_props.append(prop)
+                deferred_props.append(prop.key)
                 continue
             if c.primary_key or not c.key in params:
                 continue
@@ -1189,7 +1189,7 @@ class Mapper(object):
                 self.set_attr_by_column(obj, c, params[c.key])
         
         if deferred_props:
-            deferred_load(obj, props=deferred_props)
+            expire_instance(obj, deferred_props)
 
     def delete_obj(self, objects, uowtransaction):
         """Issue ``DELETE`` statements for a list of objects.
@@ -1342,7 +1342,7 @@ class Mapper(object):
 
         return self.__surrogate_mapper or self
 
-    def _instance(self, context, row, result = None, skip_polymorphic=False):
+    def _instance(self, context, row, result=None, skip_polymorphic=False, extension=None, only_load_props=None, refresh_instance=None):
         """Pull an object instance from the given row and append it to
         the given result list.
 
@@ -1351,36 +1351,35 @@ class Mapper(object):
         on the instance to also process extra information in the row.
         """
 
-        # apply ExtensionOptions applied to the Query to this mapper,
-        # but only if our mapper matches.
-        # TODO: what if our mapper inherits from the mapper (i.e. as in a polymorphic load?)
-        if context.mapper is self:
-            extension = context.extension
-        else:
+        if not extension:
             extension = self.extension
-
+            
         if 'translate_row' in extension.methods:
             ret = extension.translate_row(self, context, row)
             if ret is not EXT_CONTINUE:
                 row = ret
 
-        if not skip_polymorphic and self.polymorphic_on is not None:
-            discriminator = row[self.polymorphic_on]
-            if discriminator is not None:
-                mapper = self.polymorphic_map[discriminator]
-                if mapper is not self:
-                    if ('polymorphic_fetch', mapper) not in context.attributes:
-                        context.attributes[('polymorphic_fetch', mapper)] = (self, [t for t in mapper.tables if t not in self.tables])
-                    row = self.translate_row(mapper, row)
-                    return mapper._instance(context, row, result=result, skip_polymorphic=True)
+        if refresh_instance is None:
+            if not skip_polymorphic and self.polymorphic_on is not None:
+                discriminator = row[self.polymorphic_on]
+                if discriminator is not None:
+                    mapper = self.polymorphic_map[discriminator]
+                    if mapper is not self:
+                        if ('polymorphic_fetch', mapper) not in context.attributes:
+                            context.attributes[('polymorphic_fetch', mapper)] = (self, [t for t in mapper.tables if t not in self.tables])
+                        row = self.translate_row(mapper, row)
+                        return mapper._instance(context, row, result=result, skip_polymorphic=True)
         
-        # look in main identity map.  if its there, we dont do anything to it,
-        # including modifying any of its related items lists, as its already
-        # been exposed to being modified by the application.
 
-        identitykey = self.identity_key_from_row(row)
+        # determine identity key 
+        if refresh_instance:
+            identitykey = refresh_instance._instance_key
+        else:
+            identitykey = self.identity_key_from_row(row)
         (session_identity_map, local_identity_map) = (context.session.identity_map, context.identity_map)
-        
+
+        # look in main identity map.  if present, we only populate
+        # if repopulate flags are set.  this block returns the instance.
         if identitykey in session_identity_map:
             instance = session_identity_map[identitykey]
 
@@ -1397,19 +1396,21 @@ class Mapper(object):
                 if identitykey not in local_identity_map:
                     local_identity_map[identitykey] = instance
                     isnew = True
-                if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
-                    self.populate_instance(context, instance, row, instancekey=identitykey, isnew=isnew)
+                if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, instancekey=identitykey, isnew=isnew, only_load_props=only_load_props) is EXT_CONTINUE:
+                    self.populate_instance(context, instance, row, instancekey=identitykey, isnew=isnew, only_load_props=only_load_props)
 
             if 'append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
                 if result is not None:
                     result.append(instance)
+            
             return instance
-        else:
-            if self.__should_log_debug:
-                self.__log_debug("_instance(): identity key %s not in session" % str(identitykey))
+            
+        elif self.__should_log_debug:
+            self.__log_debug("_instance(): identity key %s not in session" % str(identitykey))
                 
-        # look in result-local identitymap for it.
+        # look in identity map which is local to this load operation
         if identitykey not in local_identity_map:
+            # check that sufficient primary key columns are present
             if self.allow_null_pks:
                 # check if *all* primary key cols in the result are None - this indicates
                 # an instance of the object is not present in the row.
@@ -1424,7 +1425,6 @@ class Mapper(object):
                 if None in identitykey[1]:
                     return None
 
-            # plugin point
             if 'create_instance' in extension.methods:
                 instance = extension.create_instance(self, context, row, self.class_)
                 if instance is EXT_CONTINUE:
@@ -1433,24 +1433,26 @@ class Mapper(object):
                 instance = attribute_manager.new_instance(self.class_)
                 
             instance._entity_name = self.entity_name
+            instance._instance_key = identitykey
+
             if self.__should_log_debug:
                 self.__log_debug("_instance(): created new instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey)))
+
             local_identity_map[identitykey] = instance
             isnew = True
         else:
+            # instance is already present
             instance = local_identity_map[identitykey]
             isnew = False
 
-        # call further mapper properties on the row, to pull further
-        # instances from the row and possibly populate this item.
+        # populate.  note that we still call this for an instance already loaded as additional collection state is present
+        # in subsequent rows (i.e. eagerly loaded collections)
         flags = {'instancekey':identitykey, 'isnew':isnew}
-        if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, **flags) is EXT_CONTINUE:
-            self.populate_instance(context, instance, row, **flags)
+        if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, **flags) is EXT_CONTINUE:
+            self.populate_instance(context, instance, row, only_load_props=only_load_props, **flags)
         if 'append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, **flags) is EXT_CONTINUE:
             if result is not None:
                 result.append(instance)
-                
-        instance._instance_key = identitykey
         
         return instance
 
@@ -1487,13 +1489,14 @@ class Mapper(object):
         """
         
         if tomapper in self._row_translators:
+            # row translators are cached based on target mapper
             return self._row_translators[tomapper](row)
         else:
             translator = create_row_adapter(self.mapped_table, tomapper.mapped_table, equivalent_columns=self._equivalent_columns)
             self._row_translators[tomapper] = translator
             return translator(row)
 
-    def populate_instance(self, selectcontext, instance, row, ispostselect=None, isnew=False, **flags):
+    def populate_instance(self, selectcontext, instance, row, ispostselect=None, isnew=False, only_load_props=None, **flags):
         """populate an instance from a result row."""
 
         snapshot = selectcontext.path + (self,)
@@ -1511,6 +1514,8 @@ class Mapper(object):
             existing_populators = []
             post_processors = []
             for prop in self.__props.values():
+                if only_load_props and prop.key not in only_load_props:
+                    continue
                 (newpop, existingpop, post_proc) = selectcontext.exec_with_path(self, prop.key, prop.create_row_processor, selectcontext, self, row)
                 if newpop is not None:
                     new_populators.append((prop.key, newpop))
@@ -1518,7 +1523,8 @@ class Mapper(object):
                     existing_populators.append((prop.key, existingpop))
                 if post_proc is not None:
                     post_processors.append(post_proc)
-                
+            
+            # install a post processor for immediate post-load of joined-table inheriting mappers
             poly_select_loader = self._get_poly_select_loader(selectcontext, row)
             if poly_select_loader is not None:
                 post_processors.append(poly_select_loader)
index f2a60315c142223a1c09c4706d19d5d50c53d42c..df52512eed263e01fb046bf3f847b934f0cade4a 100644 (file)
@@ -599,7 +599,7 @@ class PropertyLoader(StrategizedProperty):
 
             if self.backref is not None:
                 self.backref.compile(self)
-        elif not sessionlib.attribute_manager.is_class_managed(self.parent.class_, self.key):
+        elif not mapper.class_mapper(self.parent.class_).get_property(self.key, raiseerr=False):
             raise exceptions.ArgumentError("Attempting to assign a new relation '%s' to a non-primary mapper on class '%s'.  New relations can only be added to the primary mapper, i.e. the very first mapper created for class '%s' " % (self.key, self.parent.class_.__name__, self.parent.class_.__name__))
 
         super(PropertyLoader, self).do_init()
@@ -702,28 +702,5 @@ class BackRef(object):
 
         return attributes.GenericBackrefExtension(self.key)
 
-def deferred_load(instance, props):
-    """set multiple instance attributes to 'deferred' or 'lazy' load, for the given set of MapperProperty objects.
-
-    this will remove the current value of the attribute and set a per-instance
-    callable to fire off when the instance is next accessed.
-    
-    for column-based properties, aggreagtes them into a single list against a single deferred loader
-    so that a single column access loads all columns
-
-    """
-
-    if not props:
-        return
-    column_props = [p for p in props if isinstance(p, ColumnProperty)]
-    callable_ = column_props[0]._get_strategy(strategies.DeferredColumnLoader).setup_loader(instance, props=column_props)
-    for p in column_props:
-        sessionlib.attribute_manager.init_instance_attribute(instance, p.key, callable_=callable_, clear=True)
-        
-    for p in [p for p in props if isinstance(p, PropertyLoader)]:
-        callable_ = p._get_strategy(strategies.LazyLoader).setup_loader(instance)
-        sessionlib.attribute_manager.init_instance_attribute(instance, p.key, callable_=callable_, clear=True)
-
 mapper.ColumnProperty = ColumnProperty
-mapper.deferred_load = deferred_load
         
index a4460ea2542f94274b1c060d6b88070364f43b6d..b828750d4a6f5920dad533de5699baeed0c09564 100644 (file)
@@ -50,6 +50,8 @@ class Query(object):
         self._attributes = {}
         self._current_path = ()
         self._primary_adapter=None
+        self._only_load_props = None
+        self._refresh_instance = None
         
     def _clone(self):
         q = Query.__new__(Query)
@@ -103,12 +105,14 @@ class Query(object):
         key column values in the order of the table def's primary key
         columns.
         """
-
+        print "LOAD CHECK1"
         ret = self._extension.load(self, ident, **kwargs)
         if ret is not mapper.EXT_CONTINUE:
             return ret
+        print "LOAD CHECK2"
         key = self.mapper.identity_key_from_primary_key(ident)
-        instance = self._get(key, ident, reload=True, **kwargs)
+        instance = self.populate_existing()._get(key, ident, **kwargs)
+        print "LOAD CHECK3"
         if instance is None and raiseerr:
             raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident))
         return instance
@@ -655,7 +659,7 @@ class Query(object):
         return self._execute_and_instances(context)
     
     def _execute_and_instances(self, querycontext):
-        result = self.session.execute(querycontext.statement, params=self._params, mapper=self.mapper)
+        result = self.session.execute(querycontext.statement, params=self._params, mapper=self.mapper, instance=self._refresh_instance)
         try:
             return iter(self.instances(result, querycontext=querycontext))
         finally:
@@ -670,8 +674,6 @@ class Query(object):
         and add_column().
         """
 
-        self.__log_debug("instances()")
-
         session = self.session
 
         context = kwargs.pop('querycontext', None)
@@ -713,54 +715,68 @@ class Query(object):
             result = []
         else:
             result = util.UniqueAppender([])
-                    
+        
+        primary_mapper_args = dict(extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance)
+        
         for row in cursor.fetchall():
             if self._primary_adapter:
-                self.select_mapper._instance(context, self._primary_adapter(row), result)
+                self.select_mapper._instance(context, self._primary_adapter(row), result, **primary_mapper_args)
             else:
-                self.select_mapper._instance(context, row, result)
+                self.select_mapper._instance(context, row, result, **primary_mapper_args)
             for proc in process:
                 proc[0](context, row)
 
         for instance in context.identity_map.values():
             context.attributes.get(('populating_mapper', id(instance)), object_mapper(instance))._post_instance(context, instance)
-        
+
+        if context.refresh_instance and context.only_load_props and context.refresh_instance._instance_key in context.identity_map:
+            # if refreshing partial instance, do special state commit
+            # affecting only the refreshed attributes
+            context.refresh_instance._state.commit(context.only_load_props)
+            del context.identity_map[context.refresh_instance._instance_key]
+            
         # store new stuff in the identity map
         for instance in context.identity_map.values():
             session._register_persistent(instance)
-
+        
         if mappers_or_columns:
             return list(util.OrderedSet(zip(*([result] + [o[1] for o in process]))))
         else:
             return result.data
 
 
-    def _get(self, key, ident=None, reload=False, lockmode=None):
+    def _get(self, key=None, ident=None, refresh_instance=None, lockmode=None, only_load_props=None):
         lockmode = lockmode or self._lockmode
-        if not reload and not self.mapper.always_refresh and lockmode is None:
+        if not self._populate_existing and not refresh_instance and not self.mapper.always_refresh and lockmode is None:
             try:
                 return self.session.identity_map[key]
             except KeyError:
                 pass
-
+            
         if ident is None:
-            ident = key[1]
+            if key is not None:
+                ident = key[1]
         else:
             ident = util.to_list(ident)
-        params = {}
+
+        q = self
         
-        (_get_clause, _get_params) = self.select_mapper._get_clause
-        for i, primary_key in enumerate(self.primary_key_columns):
-            try:
-                params[_get_params[primary_key].key] = ident[i]
-            except IndexError:
-                raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in self.primary_key_columns]))
+        if ident is not None:
+            params = {}
+            (_get_clause, _get_params) = self.select_mapper._get_clause
+            q = q.filter(_get_clause)
+            for i, primary_key in enumerate(self.primary_key_columns):
+                try:
+                    params[_get_params[primary_key].key] = ident[i]
+                except IndexError:
+                    raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in self.primary_key_columns]))
+            q = q.params(params)
+            
         try:
-            q = self
             if lockmode is not None:
                 q = q.with_lockmode(lockmode)
-            q = q.filter(_get_clause)
-            q = q.params(params)._select_context_options(populate_existing=reload, version_check=(lockmode is not None))
+            q = q._select_context_options(populate_existing=refresh_instance is not None, version_check=(lockmode is not None), only_load_props=only_load_props, refresh_instance=refresh_instance)
+            q = q.order_by(None)
             # call using all() to avoid LIMIT compilation complexity
             return q.all()[0]
         except IndexError:
@@ -861,7 +877,9 @@ class Query(object):
         # TODO: doing this off the select_mapper.  if its the polymorphic mapper, then
         # it has no relations() on it.  should we compile those too into the query ?  (i.e. eagerloads)
         for value in self.select_mapper.iterate_properties:
-            context.exec_with_path(self.select_mapper, value.key, value.setup, context)
+            if self._only_load_props and value.key not in self._only_load_props:
+                continue
+            context.exec_with_path(self.select_mapper, value.key, value.setup, context, only_load_props=self._only_load_props)
 
         # additional entities/columns, add those to selection criterion
         for tup in self._entities:
@@ -912,7 +930,6 @@ class Query(object):
             statement.append_order_by(*context.eager_order_by)
         else:
             statement = sql.select(context.primary_columns + context.secondary_columns, whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **self._select_args())
-
             if context.eager_joins:
                 statement.append_from(context.eager_joins, _copy_collection=False)
 
@@ -1101,11 +1118,15 @@ class Query(object):
         q._select_context_options(**kwargs)
         return list(q)
 
-    def _select_context_options(self, populate_existing=None, version_check=None): #pragma: no cover
-        if populate_existing is not None:
+    def _select_context_options(self, populate_existing=None, version_check=None, only_load_props=None, refresh_instance=None): #pragma: no cover
+        if populate_existing:
             self._populate_existing = populate_existing
-        if version_check is not None:
+        if version_check:
             self._version_check = version_check
+        if refresh_instance is not None:
+            self._refresh_instance = refresh_instance
+        if only_load_props:
+            self._only_load_props = util.Set(only_load_props)
         return self
         
     def join_to(self, key): #pragma: no cover
@@ -1207,6 +1228,8 @@ class QueryContext(object):
         self.statement = None
         self.populate_existing = query._populate_existing
         self.version_check = query._version_check
+        self.only_load_props = query._only_load_props
+        self.refresh_instance = query._refresh_instance
         self.identity_map = {}
         self.path = ()
         self.primary_columns = []
index 0b57447788f96c1d8fe2da46935792887e14a01f..7cdc7f671998eeef72fc75615445b4d820d8ac0f 100644 (file)
@@ -6,7 +6,7 @@
 
 import weakref
 from sqlalchemy import util, exceptions, sql, engine
-from sqlalchemy.orm import unitofwork, query, util as mapperutil
+from sqlalchemy.orm import unitofwork, query, attributes, util as mapperutil
 from sqlalchemy.orm.mapper import object_mapper as _object_mapper
 from sqlalchemy.orm.mapper import class_mapper as _class_mapper
 from sqlalchemy.orm.mapper import Mapper
@@ -522,9 +522,9 @@ class Session(object):
         resources of the underlying ``Connection``.
         """
 
-        engine = self.get_bind(mapper, clause=clause)
+        engine = self.get_bind(mapper, clause=clause, **kwargs)
         
-        return self.__connection(engine, close_with_result=True).execute(clause, params or {}, **kwargs)
+        return self.__connection(engine, close_with_result=True).execute(clause, params or {})
 
     def scalar(self, clause, params=None, mapper=None, **kwargs):
         """Like execute() but return a scalar result."""
@@ -716,26 +716,57 @@ class Session(object):
         entity_name = kwargs.pop('entity_name', None)
         return self.query(class_, entity_name=entity_name).load(ident, **kwargs)
 
-    def refresh(self, obj):
-        """Reload the attributes for the given object from the
-        database, clear any changes made.
+    def refresh(self, obj, attribute_names=None):
+        """Refresh the attributes on the given instance.
+        
+        When called, a query will be issued
+        to the database which will refresh all attributes with their
+        current value.  
+        
+        Lazy-loaded relational attributes will remain lazily loaded, so that 
+        the instance-wide refresh operation will be followed
+        immediately by the lazy load of that attribute.
+        
+        Eagerly-loaded relational attributes will eagerly load within the
+        single refresh operation.
+        
+        The ``attribute_names`` argument is an iterable collection
+        of attribute names indicating a subset of attributes to be 
+        refreshed.
         """
 
         self._validate_persistent(obj)
-        if self.query(obj.__class__)._get(obj._instance_key, reload=True) is None:
+            
+        if self.query(obj.__class__)._get(obj._instance_key, refresh_instance=obj, only_load_props=attribute_names) is None:
             raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(obj))
 
-    def expire(self, obj):
-        """Mark the given object as expired.
-
-        This will add an instrumentation to all mapped attributes on
-        the instance such that when an attribute is next accessed, the
-        session will reload all attributes on the instance from the
-        database.
+    def expire(self, obj, attribute_names=None):
+        """Expire the attributes on the given instance.
+        
+        The instance's attributes are instrumented such that
+        when an attribute is next accessed, a query will be issued
+        to the database which will refresh all attributes with their
+        current value.  
+        
+        Lazy-loaded relational attributes will remain lazily loaded, so that 
+        triggering one will incur the instance-wide refresh operation, followed
+        immediately by the lazy load of that attribute.
+        
+        Eagerly-loaded relational attributes will eagerly load within the
+        single refresh operation.
+        
+        The ``attribute_names`` argument is an iterable collection
+        of attribute names indicating a subset of attributes to be 
+        expired.
         """
-
-        for c in [obj] + list(_object_mapper(obj).cascade_iterator('refresh-expire', obj)):
-            self._expire_impl(c)
+        
+        if attribute_names:
+            self._validate_persistent(obj)
+            expire_instance(obj, attribute_names=attribute_names)
+        else:
+            for c in [obj] + list(_object_mapper(obj).cascade_iterator('refresh-expire', obj)):
+                self._validate_persistent(obj)
+                expire_instance(c, None)
 
     def prune(self):
         """Removes unreferenced instances cached in the identity map.
@@ -750,21 +781,12 @@ class Session(object):
 
         return self.uow.prune_identity_map()
 
-    def _expire_impl(self, obj):
-        self._validate_persistent(obj)
-
-        def exp():
-            if self.query(obj.__class__)._get(obj._instance_key, reload=True) is None:
-                raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(obj))
-
-        attribute_manager.trigger_history(obj, exp)
-
     def is_expired(self, obj, unexpire=False):
         """Return True if the given object has been marked as expired."""
 
-        ret = attribute_manager.has_trigger(obj)
+        ret = obj._state.trigger is not None
         if ret and unexpire:
-            attribute_manager.untrigger_history(obj)
+            obj._state.trigger = None
         return ret
 
     def expunge(self, object):
@@ -999,7 +1021,7 @@ class Session(object):
     def _register_persistent(self, obj):
         obj._sa_session_id = self.hash_key
         self.identity_map[obj._instance_key] = obj
-        attribute_manager.commit(obj)
+        obj._state.commit_all()
 
     def _attach(self, obj):
         old_id = getattr(obj, '_sa_session_id', None)
@@ -1083,6 +1105,25 @@ class Session(object):
     new = property(lambda s:s.uow.new,
                    doc="A ``Set`` of all objects marked as 'new' within this ``Session``.")
 
+def expire_instance(obj, attribute_names):
+    """standalone expire instance function. 
+    
+    installs a callable with the given instance's _state
+    which will fire off when any of the named attributes are accessed;
+    their existing value is removed.
+    
+    If the list is None or blank, the entire instance is expired.
+    """
+    
+    if obj._state.trigger is None:
+        def load_attributes(instance, attribute_names):
+            if object_session(instance).query(instance.__class__)._get(instance._instance_key, refresh_instance=instance, only_load_props=attribute_names) is None:
+                raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance))
+        obj._state.trigger = load_attributes
+        
+    obj._state.expire_attributes(attribute_names)
+    
+
 
 # this is the AttributeManager instance used to provide attribute behavior on objects.
 # to all the "global variable police" out there:  its a stateless object.
@@ -1108,3 +1149,4 @@ def object_session(obj):
 unitofwork.object_session = object_session
 from sqlalchemy.orm import mapper
 mapper.attribute_manager = attribute_manager
+mapper.expire_instance = expire_instance
\ No newline at end of file
index 8574d2fef22b883e6aa77586efb6c726e275ccf5..dfd1efa36e224e92f3b03ff957c3ad96e8f0efb7 100644 (file)
@@ -114,8 +114,7 @@ class ColumnLoader(LoaderStrategy):
             
             def new_execute(instance, row, isnew, **flags):
                 if isnew:
-                    loader = strategy.setup_loader(instance, props=props, create_statement=create_statement)
-                    sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=loader)
+                    instance._state.set_callable(self.key, strategy.setup_loader(instance, props=props, create_statement=create_statement))
                     
             if self._should_log_debug:
                 self.logger.debug("Returning deferred column fetcher for %s %s" % (mapper, self.key))
@@ -134,19 +133,19 @@ class DeferredColumnLoader(LoaderStrategy):
     """Deferred column loader, a per-column or per-column-group lazy loader."""
     
     def create_row_processor(self, selectcontext, mapper, row):
-        if (self.group is not None and selectcontext.attributes.get(('undefer', self.group), False)) or self.columns[0] in row:
+        if self.columns[0] in row:
             return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, mapper, row)
         elif not self.is_class_level or len(selectcontext.options):
             def new_execute(instance, row, **flags):
                 if self._should_log_debug:
                     self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key))
-                sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self.setup_loader(instance))
+                instance._state.set_callable(self.key, self.setup_loader(instance))
             return (new_execute, None, None)
         else:
             def new_execute(instance, row, **flags):
                 if self._should_log_debug:
                     self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key))
-                sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
+                instance._state.reset(self.key)
             return (new_execute, None, None)
 
     def init(self):
@@ -162,8 +161,11 @@ class DeferredColumnLoader(LoaderStrategy):
         self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__))
         sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, callable_=self.setup_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator)
 
-    def setup_query(self, context, **kwargs):
-        if self.group is not None and context.attributes.get(('undefer', self.group), False):
+    def setup_query(self, context, only_load_props=None, **kwargs):
+        if \
+            (self.group is not None and context.attributes.get(('undefer', self.group), False)) or \
+            (only_load_props and self.key in only_load_props):
+            
             self.parent_property._get_strategy(ColumnLoader).setup_query(context, **kwargs)
         
     def setup_loader(self, instance, props=None, create_statement=None):
@@ -198,27 +200,12 @@ class DeferredColumnLoader(LoaderStrategy):
                 raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
 
             if create_statement is None:
-                (clause, param_map) = localparent._get_clause
                 ident = instance._instance_key[1]
-                params = {}
-                for i, primary_key in enumerate(localparent.primary_key):
-                    params[param_map[primary_key].key] = ident[i]
-                statement = sql.select([p.columns[0] for p in group], clause, from_obj=[localparent.mapped_table], use_labels=True)
+                session.query(localparent)._get(None, ident=ident, only_load_props=[p.key for p in group], refresh_instance=instance)
             else:
                 statement, params = create_statement(instance)
-            
-            # TODO: have the "fetch of one row" operation go through the same channels as a query._get()
-            # deferred load of several attributes should be a specialized case of a query refresh operation
-            conn = session.connection(mapper=localparent, instance=instance)
-            result = conn.execute(statement, params)
-            try:
-                row = result.fetchone()
-                for prop in group:
-                    sessionlib.attribute_manager.set_committed_value(instance, prop.key, row[prop.columns[0]])
-                return attributes.ATTR_WAS_SET
-            finally:
-                result.close()
-
+                session.query(localparent).from_statement(statement).params(params)._get(None, only_load_props=[p.key for p in group], refresh_instance=instance)
+            return attributes.ATTR_WAS_SET
         return lazyload
                 
 DeferredColumnLoader.logger = logging.class_logger(DeferredColumnLoader)
@@ -248,7 +235,10 @@ class AbstractRelationLoader(LoaderStrategy):
         self._should_log_debug = logging.is_debug_enabled(self.logger)
         
     def _init_instance_attribute(self, instance, callable_=None):
-        return sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=callable_)
+        if callable_:
+            instance._state.set_callable(self.key, callable_)
+        else:
+            instance._state.initialize(self.key)
         
     def _register_attribute(self, class_, callable_=None, **kwargs):
         self.logger.info("register managed %s attribute %s on class %s" % ((self.uselist and "list-holding" or "scalar"), self.key, self.parent.class_.__name__))
@@ -399,7 +389,7 @@ class LazyLoader(AbstractRelationLoader):
                     # so that the class-level lazy loader is executed when next referenced on this instance.
                     # this usually is not needed unless the constructor of the object referenced the attribute before we got 
                     # to load data into it.
-                    sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
+                    instance._state.reset(self.key)
             return (new_execute, None, None)
 
     def _create_lazy_clause(cls, prop, reverse_direction=False):
@@ -603,10 +593,7 @@ class EagerLoader(AbstractRelationLoader):
                         # parent object, bypassing InstrumentedAttribute
                         # event handlers.
                         #
-                        # FIXME: instead of...
-                        sessionlib.attribute_manager.set_raw_value(instance, self.key, self.select_mapper._instance(selectcontext, decorated_row, None))
-                        # bypass and set directly:
-                        #instance.__dict__[self.key] = self.select_mapper._instance(selectcontext, decorated_row, None)
+                        instance.__dict__[self.key] = self.select_mapper._instance(selectcontext, decorated_row, None)
                     else:
                         # call _instance on the row, even though the object has been created,
                         # so that we further descend into properties
index 7f9a4d7d06e5f0a1f2454f26929216a7e4726d01..2cd7cb6f5dd958c9c46f74258a19537b2919b66a 100644 (file)
@@ -128,7 +128,7 @@ class UnitOfWork(object):
         if hasattr(obj, '_sa_insert_order'):
             delattr(obj, '_sa_insert_order')
         self.identity_map[obj._instance_key] = obj
-        attribute_manager.commit(obj)
+        obj._state.commit_all()
 
     def register_new(self, obj):
         """register the given object as 'new' (i.e. unsaved) within this unit of work."""
index 59357c7b71095db163cdee4c1f0cd942a059d564..059d7a100e71e4388f8c09a3c1aecc39fe707aff 100644 (file)
@@ -11,6 +11,7 @@ def suite():
         'orm.lazy_relations',
         'orm.eager_relations',
         'orm.mapper',
+        'orm.expire',
         'orm.selectable',
         'orm.collection',
         'orm.generative',
index 930bfa57e20a28692bc6f60b895dcbce64df2862..2080474eddceed506e3bf4d7e3f854d514074112 100644 (file)
@@ -5,6 +5,8 @@ from sqlalchemy.orm.collections import collection
 from sqlalchemy import exceptions
 from testlib import *
 
+ROLLBACK_SUPPORTED=False
+
 # these test classes defined at the module
 # level to support pickling
 class MyTest(object):pass
@@ -29,7 +31,7 @@ class AttributesTest(PersistTest):
         
         print repr(u.__dict__)
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
-        manager.commit(u)
+        u._state.commit_all()
         print repr(u.__dict__)
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
 
@@ -37,10 +39,11 @@ class AttributesTest(PersistTest):
         u.email_address = 'foo@bar.com'
         print repr(u.__dict__)
         self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.email_address == 'foo@bar.com')
-        
-        manager.rollback(u)
-        print repr(u.__dict__)
-        self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
+
+        if ROLLBACK_SUPPORTED:
+            manager.rollback(u)
+            print repr(u.__dict__)
+            self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
 
     def test_pickleness(self):
 
@@ -128,7 +131,7 @@ class AttributesTest(PersistTest):
 
         print repr(u.__dict__)
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
-        manager.commit(u, a)
+        u, a._state.commit_all()
         print repr(u.__dict__)
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
 
@@ -140,11 +143,12 @@ class AttributesTest(PersistTest):
         print repr(u.__dict__)
         self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.addresses[0].email_address == 'lala@123.com' and u.addresses[1].email_address == 'foo@bar.com')
 
-        manager.rollback(u, a)
-        print repr(u.__dict__)
-        print repr(u.addresses[0].__dict__)
-        self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
-        self.assert_(len(manager.get_history(u, 'addresses').unchanged_items()) == 1)
+        if ROLLBACK_SUPPORTED:
+            manager.rollback(u, a)
+            print repr(u.__dict__)
+            print repr(u.addresses[0].__dict__)
+            self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
+            self.assert_(len(manager.get_history(u, 'addresses').unchanged_items()) == 1)
 
     def test_backref(self):
         class Student(object):pass
@@ -231,9 +235,9 @@ class AttributesTest(PersistTest):
         # create objects as if they'd been freshly loaded from the database (without history)
         b = Blog()
         p1 = Post()
-        manager.init_instance_attribute(b, 'posts', lambda:[p1])
-        manager.init_instance_attribute(p1, 'blog', lambda:b)
-        manager.commit(p1, b)
+        b._state.set_callable('posts', lambda:[p1])
+        p1._state.set_callable('blog', lambda:b)
+        p1, b._state.commit_all()
 
         # no orphans (called before the lazy loaders fire off)
         assert manager.has_parent(Blog, p1, 'posts', optimistic=True)
@@ -292,7 +296,7 @@ class AttributesTest(PersistTest):
         x.element = 'this is the element'
         hist = manager.get_history(x, 'element')
         assert hist.added_items() == ['this is the element']
-        manager.commit(x)
+        x._state.commit_all()
         hist = manager.get_history(x, 'element')
         assert hist.added_items() == []
         assert hist.unchanged_items() == ['this is the element']
@@ -320,7 +324,7 @@ class AttributesTest(PersistTest):
         manager.register_attribute(Bar, 'id', uselist=False, useobject=True)
 
         x = Foo()
-        manager.commit(x)
+        x._state.commit_all()
         x.col2.append(Bar(4))
         h = manager.get_history(x, 'col2')
         print h.added_items()
@@ -362,7 +366,7 @@ class AttributesTest(PersistTest):
         manager.register_attribute(Foo, 'element', uselist=False, copy_function=lambda x:[y for y in x], mutable_scalars=True, useobject=False)
         x = Foo()
         x.element = ['one', 'two', 'three']    
-        manager.commit(x)
+        x._state.commit_all()
         x.element[1] = 'five'
         assert manager.is_modified(x)
         
@@ -372,7 +376,7 @@ class AttributesTest(PersistTest):
         manager.register_attribute(Foo, 'element', uselist=False, useobject=False)
         x = Foo()
         x.element = ['one', 'two', 'three']    
-        manager.commit(x)
+        x._state.commit_all()
         x.element[1] = 'five'
         assert not manager.is_modified(x)
         
index 80e3982da208d17d116954387f3205188b2ad3c3..0bb253b198be51c197244e22854dcbd6a2a2f54b 100644 (file)
@@ -7,12 +7,10 @@ from testlib.fixtures import *
 
 from query import QueryTest
 
-class DynamicTest(QueryTest):
+class DynamicTest(FixtureTest):
     keep_mappers = False
-
-    def setup_mappers(self):
-        pass
-
+    keep_data = True
+    
     def test_basic(self):
         mapper(User, users, properties={
             'addresses':dynamic_loader(mapper(Address, addresses))
index 3e811b86bf19a0ee681438f5807770d75fc32e6e..52602ecae39ea32b8615f3653dfa8e6de289a309 100644 (file)
@@ -7,9 +7,10 @@ from testlib import *
 from testlib.fixtures import *
 from query import QueryTest
 
-class EagerTest(QueryTest):
+class EagerTest(FixtureTest):
     keep_mappers = False
-
+    keep_data = True
+    
     def setup_mappers(self):
         pass
 
diff --git a/test/orm/expire.py b/test/orm/expire.py
new file mode 100644 (file)
index 0000000..4301177
--- /dev/null
@@ -0,0 +1,439 @@
+"""test attribute/instance expiration, deferral of attributes, etc."""
+
+import testbase
+from sqlalchemy import *
+from sqlalchemy import exceptions
+from sqlalchemy.orm import *
+from testlib import *
+from testlib.fixtures import *
+
+class ExpireTest(FixtureTest):
+    keep_mappers = False
+    refresh_data = True
+    
+    def test_expire(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref='user'),
+            })
+        mapper(Address, addresses)
+            
+        sess = create_session()
+        u = sess.query(User).get(7)
+        assert len(u.addresses) == 1
+        u.name = 'foo'
+        del u.addresses[0]
+        sess.expire(u)
+        
+        assert 'name' not in u.__dict__
+        
+        def go():
+            assert u.name == 'jack'
+        self.assert_sql_count(testbase.db, go, 1)
+        assert 'name' in u.__dict__
+
+        # we're changing the database here, so if this test fails in the middle,
+        # it'll screw up the other tests which are hardcoded to 7/'jack'
+        u.name = 'foo'
+        sess.flush()
+        # change the value in the DB
+        users.update(users.c.id==7, values=dict(name='jack')).execute()
+        sess.expire(u)
+        # object isnt refreshed yet, using dict to bypass trigger
+        assert u.__dict__.get('name') != 'jack'
+        # reload all
+        sess.query(User).all()
+        # test that it refreshed
+        assert u.__dict__['name'] == 'jack'
+
+        # object should be back to normal now,
+        # this should *not* produce a SELECT statement (not tested here though....)
+        assert u.name == 'jack'
+    
+    def test_expire_doesntload_on_set(self):
+        mapper(User, users)
+        
+        sess = create_session()
+        u = sess.query(User).get(7)
+        
+        sess.expire(u, attribute_names=['name'])
+        def go():
+            u.name = 'somenewname'
+        self.assert_sql_count(testbase.db, go, 0)
+        sess.flush()
+        sess.clear()
+        assert sess.query(User).get(7).name == 'somenewname'
+        
+    def test_expire_committed(self):
+        """test that the committed state of the attribute receives the most recent DB data"""
+        mapper(Order, orders)
+            
+        sess = create_session()
+        o = sess.query(Order).get(3)
+        sess.expire(o)
+
+        orders.update(id=3).execute(description='order 3 modified')
+        assert o.isopen == 1
+        assert o._state.committed_state['description'] == 'order 3 modified'
+        def go():
+            sess.flush()
+        self.assert_sql_count(testbase.db, go, 0)
+        
+    def test_expire_cascade(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, cascade="all, refresh-expire")
+        })
+        mapper(Address, addresses)
+        s = create_session()
+        u = s.get(User, 8)
+        assert u.addresses[0].email_address == 'ed@wood.com'
+
+        u.addresses[0].email_address = 'someotheraddress'
+        s.expire(u)
+        u.name
+        print u._state.dict
+        assert u.addresses[0].email_address == 'ed@wood.com'
+
+    def test_expired_lazy(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref='user'),
+            })
+        mapper(Address, addresses)
+
+        sess = create_session()
+        u = sess.query(User).get(7)
+
+        sess.expire(u)
+        assert 'name' not in u.__dict__
+        assert 'addresses' not in u.__dict__
+
+        def go():
+            assert u.addresses[0].email_address == 'jack@bean.com'
+            assert u.name == 'jack'
+        # two loads 
+        self.assert_sql_count(testbase.db, go, 2)
+        assert 'name' in u.__dict__
+        assert 'addresses' in u.__dict__
+
+    def test_expired_eager(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref='user', lazy=False),
+            })
+        mapper(Address, addresses)
+
+        sess = create_session()
+        u = sess.query(User).get(7)
+
+        sess.expire(u)
+        assert 'name' not in u.__dict__
+        assert 'addresses' not in u.__dict__
+
+        def go():
+            assert u.addresses[0].email_address == 'jack@bean.com'
+            assert u.name == 'jack'
+        # one load
+        self.assert_sql_count(testbase.db, go, 1)
+        assert 'name' in u.__dict__
+        assert 'addresses' in u.__dict__
+
+    def test_partial_expire(self):
+        mapper(Order, orders)
+
+        sess = create_session()
+        o = sess.query(Order).get(3)
+        
+        sess.expire(o, attribute_names=['description'])
+        assert 'id' in o.__dict__
+        assert 'description' not in o.__dict__
+        assert o._state.committed_state['isopen'] == 1
+        
+        orders.update(orders.c.id==3).execute(description='order 3 modified')
+        
+        def go():
+            assert o.description == 'order 3 modified'
+        self.assert_sql_count(testbase.db, go, 1)
+        assert o._state.committed_state['description'] == 'order 3 modified'
+        
+        o.isopen = 5
+        sess.expire(o, attribute_names=['description'])
+        assert 'id' in o.__dict__
+        assert 'description' not in o.__dict__
+        assert o.__dict__['isopen'] == 5
+        assert o._state.committed_state['isopen'] == 1
+        
+        def go():
+            assert o.description == 'order 3 modified'
+        self.assert_sql_count(testbase.db, go, 1)
+        assert o.__dict__['isopen'] == 5
+        assert o._state.committed_state['description'] == 'order 3 modified'
+        assert o._state.committed_state['isopen'] == 1
+
+        sess.flush()
+        
+        sess.expire(o, attribute_names=['id', 'isopen', 'description'])
+        assert 'id' not in o.__dict__
+        assert 'isopen' not in o.__dict__
+        assert 'description' not in o.__dict__
+        def go():
+            assert o.description == 'order 3 modified'
+            assert o.id == 3
+            assert o.isopen == 5
+        self.assert_sql_count(testbase.db, go, 1)
+
+    def test_partial_expire_lazy(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref='user'),
+            })
+        mapper(Address, addresses)
+
+        sess = create_session()
+        u = sess.query(User).get(8)
+        
+        sess.expire(u, ['name', 'addresses'])
+        assert 'name' not in u.__dict__
+        assert 'addresses' not in u.__dict__
+        
+        # hit the lazy loader.  just does the lazy load,
+        # doesnt do the overall refresh
+        def go():
+            assert u.addresses[0].email_address=='ed@wood.com'
+        self.assert_sql_count(testbase.db, go, 1)
+        
+        assert 'name' not in u.__dict__
+        
+        # check that mods to expired lazy-load attributes 
+        # only do the lazy load
+        sess.expire(u, ['name', 'addresses'])
+        def go():
+            u.addresses = [Address(id=10, email_address='foo@bar.com')]
+        self.assert_sql_count(testbase.db, go, 1)
+        
+        sess.flush()
+        
+        # flush has occurred, and addresses was modified, 
+        # so the addresses collection got committed and is
+        # longer expired
+        def go():
+            assert u.addresses[0].email_address=='foo@bar.com'
+            assert len(u.addresses) == 1
+        self.assert_sql_count(testbase.db, go, 0)
+        
+        # but the name attribute was never loaded and so
+        # still loads
+        def go():
+            assert u.name == 'ed'
+        self.assert_sql_count(testbase.db, go, 1)
+
+    def test_partial_expire_eager(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref='user', lazy=False),
+            })
+        mapper(Address, addresses)
+
+        sess = create_session()
+        u = sess.query(User).get(8)
+
+        sess.expire(u, ['name', 'addresses'])
+        assert 'name' not in u.__dict__
+        assert 'addresses' not in u.__dict__
+
+        def go():
+            assert u.addresses[0].email_address=='ed@wood.com'
+        self.assert_sql_count(testbase.db, go, 1)
+
+        # check that mods to expired eager-load attributes 
+        # do the refresh
+        sess.expire(u, ['name', 'addresses'])
+        def go():
+            u.addresses = [Address(id=10, email_address='foo@bar.com')]
+        self.assert_sql_count(testbase.db, go, 1)
+        sess.flush()
+
+        # this should ideally trigger the whole load
+        # but currently it works like the lazy case
+        def go():
+            assert u.addresses[0].email_address=='foo@bar.com'
+            assert len(u.addresses) == 1
+        self.assert_sql_count(testbase.db, go, 0)
+        
+        def go():
+            assert u.name == 'ed'
+        # scalar attributes have their own load
+        self.assert_sql_count(testbase.db, go, 1)
+        # ideally, this was already loaded, but we arent
+        # doing it that way right now
+        #self.assert_sql_count(testbase.db, go, 0)
+
+    def test_partial_expire_deferred(self):
+        mapper(Order, orders, properties={
+            'description':deferred(orders.c.description)
+        })
+        
+        sess = create_session()
+        o = sess.query(Order).get(3)
+        sess.expire(o, ['description', 'isopen'])
+        assert 'isopen' not in o.__dict__
+        assert 'description' not in o.__dict__
+        
+        # test that expired attribute access refreshes
+        # the deferred
+        def go():
+            assert o.isopen == 1
+            assert o.description == 'order 3'
+        self.assert_sql_count(testbase.db, go, 1)
+        
+        sess.expire(o, ['description', 'isopen'])
+        assert 'isopen' not in o.__dict__
+        assert 'description' not in o.__dict__
+        # test that the deferred attribute triggers the full
+        # reload
+        def go():
+            assert o.description == 'order 3'
+            assert o.isopen == 1
+        self.assert_sql_count(testbase.db, go, 1)
+        
+        clear_mappers()
+        
+        mapper(Order, orders)
+        sess.clear()
+
+        # same tests, using deferred at the options level
+        o = sess.query(Order).options(defer('description')).get(3)
+
+        assert 'description' not in o.__dict__
+
+        # sanity check
+        def go():
+            assert o.description == 'order 3'
+        self.assert_sql_count(testbase.db, go, 1)
+         
+        assert 'description' in o.__dict__
+        assert 'isopen' in o.__dict__
+        sess.expire(o, ['description', 'isopen'])
+        assert 'isopen' not in o.__dict__
+        assert 'description' not in o.__dict__
+        
+        # test that expired attribute access refreshes
+        # the deferred
+        def go():
+            assert o.isopen == 1
+            assert o.description == 'order 3'
+        self.assert_sql_count(testbase.db, go, 1)
+        sess.expire(o, ['description', 'isopen'])
+
+        assert 'isopen' not in o.__dict__
+        assert 'description' not in o.__dict__
+        # test that the deferred attribute triggers the full
+        # reload
+        def go():
+            assert o.description == 'order 3'
+            assert o.isopen == 1
+        self.assert_sql_count(testbase.db, go, 1)
+        
+
+class RefreshTest(FixtureTest):
+    keep_mappers = False
+    refresh_data = True
+
+    def test_refresh(self):
+        mapper(User, users, properties={
+            'addresses':relation(mapper(Address, addresses), backref='user')
+        })
+        s = create_session()
+        u = s.get(User, 7)
+        u.name = 'foo'
+        a = Address()
+        assert object_session(a) is None
+        u.addresses.append(a)
+        assert a.email_address is None
+        assert id(a) in [id(x) for x in u.addresses]
+
+        s.refresh(u)
+
+        # its refreshed, so not dirty
+        assert u not in s.dirty
+
+        # username is back to the DB
+        assert u.name == 'jack'
+        
+        assert id(a) not in [id(x) for x in u.addresses]
+
+        u.name = 'foo'
+        u.addresses.append(a)
+        # now its dirty
+        assert u in s.dirty
+        assert u.name == 'foo'
+        assert id(a) in [id(x) for x in u.addresses]
+        s.expire(u)
+
+        # get the attribute, it refreshes
+        assert u.name == 'jack'
+        assert id(a) not in [id(x) for x in u.addresses]
+
+    def test_refresh_expired(self):
+        mapper(User, users)
+        s = create_session()
+        u = s.get(User, 7)
+        s.expire(u)
+        assert 'name' not in u.__dict__
+        s.refresh(u)
+        assert u.name == 'jack'
+        
+    def test_refresh_with_lazy(self):
+        """test that when a lazy loader is set as a trigger on an object's attribute 
+        (at the attribute level, not the class level), a refresh() operation doesnt 
+        fire the lazy loader or create any problems"""
+        
+        s = create_session()
+        mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))})
+        q = s.query(User).options(lazyload('addresses'))
+        u = q.filter(users.c.id==8).first()
+        def go():
+            s.refresh(u)
+        self.assert_sql_count(testbase.db, go, 1)
+
+
+    def test_refresh_with_eager(self):
+        """test that a refresh/expire operation loads rows properly and sends correct "isnew" state to eager loaders"""
+        
+        mapper(User, users, properties={
+            'addresses':relation(mapper(Address, addresses), lazy=False)
+        })
+        
+        s = create_session()
+        u = s.get(User, 8)
+        assert len(u.addresses) == 3
+        s.refresh(u)
+        assert len(u.addresses) == 3
+
+        s = create_session()
+        u = s.get(User, 8)
+        assert len(u.addresses) == 3
+        s.expire(u)
+        assert len(u.addresses) == 3
+
+    @testing.fails_on('maxdb')
+    def test_refresh2(self):
+        """test a hang condition that was occuring on expire/refresh"""
+
+        s = create_session()
+        mapper(Address, addresses)
+
+        mapper(User, users, properties = dict(addresses=relation(Address,cascade="all, delete-orphan",lazy=False)) )
+        
+        u=User()
+        u.name='Justin'
+        a = Address(id=10, email_address='lala')
+        u.addresses.append(a)
+        
+        s.save(u)
+        s.flush()
+        s.clear()
+        u = s.query(User).filter(User.name=='Justin').one()
+
+        s.expire(u)
+        assert u.name == 'Justin'
+
+        s.refresh(u)
+
+if __name__ == '__main__':
+    testbase.main()
index 9fa7fffbac6b61e4c02e27d66ded7103d42ec3f9..32420300f2ddc2910d9dbd38377011f3c587fce5 100644 (file)
@@ -84,7 +84,7 @@ class GetTest(ORMTest):
             Column('bar_id', Integer, ForeignKey('bar.id')),
             Column('data', String(20)))
 
-    def create_test(polymorphic):
+    def create_test(polymorphic, name):
         def test_get(self):
             class Foo(object):
                 pass
@@ -145,11 +145,11 @@ class GetTest(ORMTest):
                     assert sess.query(Blub).get(bl.id) == bl
 
                 self.assert_sql_count(testbase.db, go, 3)
-
+        test_get.__name__ = name
         return test_get
 
-    test_get_polymorphic = create_test(True)
-    test_get_nonpolymorphic = create_test(False)
+    test_get_polymorphic = create_test(True, 'test_get_polymorphic')
+    test_get_nonpolymorphic = create_test(False, 'test_get_nonpolymorphic')
 
 
 class ConstructionTest(ORMTest):
index b8e92c16377a5782ff8fbb18255860423131e8f1..97eda3006327243f6bfb68770c51c95f17487d1e 100644 (file)
@@ -8,12 +8,10 @@ from testlib import *
 from testlib.fixtures import *
 from query import QueryTest
 
-class LazyTest(QueryTest):
+class LazyTest(FixtureTest):
     keep_mappers = False
-
-    def setup_mappers(self):
-        pass
-        
+    keep_data = True
+    
     def test_basic(self):
         mapper(User, users, properties={
             'addresses':relation(mapper(Address, addresses), lazy=True)
@@ -275,13 +273,10 @@ class LazyTest(QueryTest):
         
         assert a.user is u1
 
-class M2OGetTest(QueryTest):
+class M2OGetTest(FixtureTest):
     keep_mappers = False
-    keep_data = False
+    keep_data = True
 
-    def setup_mappers(self):
-        pass
-        
     def test_m2o_noload(self):
         """test that a NULL foreign key doesn't trigger a lazy load"""
         mapper(User, users)
index fe763c0bed8f49106004b0d719fb2c7ffcb6acd8..5ed4795941f37a059d82b403700aa1a0c91bcc71 100644 (file)
@@ -67,67 +67,12 @@ class MapperTest(MapperSuperTest):
         u2 = s.query(User).filter_by(user_name='jack').one()
         assert u is u2
 
-    def test_refresh(self):
-        mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), backref='user')})
-        s = create_session()
-        u = s.get(User, 7)
-        u.user_name = 'foo'
-        a = Address()
-        assert object_session(a) is None
-        u.addresses.append(a)
-
-        self.assert_(a in u.addresses)
-
-        s.refresh(u)
-
-        # its refreshed, so not dirty
-        self.assert_(u not in s.dirty)
-
-        # username is back to the DB
-        self.assert_(u.user_name == 'jack')
-
-        self.assert_(a not in u.addresses)
-
-        u.user_name = 'foo'
-        u.addresses.append(a)
-        # now its dirty
-        self.assert_(u in s.dirty)
-        self.assert_(u.user_name == 'foo')
-        self.assert_(a in u.addresses)
-        s.expire(u)
-
-        # get the attribute, it refreshes
-        self.assert_(u.user_name == 'jack')
-        self.assert_(a not in u.addresses)
 
     def test_compileonsession(self):
         m = mapper(User, users)
         session = create_session()
         session.connection(m)
 
-    def test_expirecascade(self):
-        mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), cascade="all, refresh-expire")})
-        s = create_session()
-        u = s.get(User, 8)
-        u.addresses[0].email_address = 'someotheraddress'
-        s.expire(u)
-        assert u.addresses[0].email_address == 'ed@wood.com'
-
-    def test_refreshwitheager(self):
-        """test that a refresh/expire operation loads rows properly and sends correct "isnew" state to eager loaders"""
-        mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=False)})
-        s = create_session()
-        u = s.get(User, 8)
-        assert len(u.addresses) == 3
-        s.refresh(u)
-        assert len(u.addresses) == 3
-
-        s = create_session()
-        u = s.get(User, 8)
-        assert len(u.addresses) == 3
-        s.expire(u)
-        assert len(u.addresses) == 3
-
     def test_incompletecolumns(self):
         """test loading from a select which does not contain all columns"""
         mapper(Address, addresses)
@@ -186,71 +131,6 @@ class MapperTest(MapperSuperTest):
         except Exception, e:
             assert e is ex
 
-    def test_refresh_lazy(self):
-        """test that when a lazy loader is set as a trigger on an object's attribute (at the attribute level, not the class level), a refresh() operation doesnt fire the lazy loader or create any problems"""
-        s = create_session()
-        mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))})
-        q2 = s.query(User).options(lazyload('addresses'))
-        u = q2.selectfirst(users.c.user_id==8)
-        def go():
-            s.refresh(u)
-        self.assert_sql_count(testbase.db, go, 1)
-
-    def test_expire(self):
-        """test the expire function"""
-        s = create_session()
-        mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=False)})
-        u = s.get(User, 7)
-        assert(len(u.addresses) == 1)
-        u.user_name = 'foo'
-        del u.addresses[0]
-        s.expire(u)
-        # test plain expire
-        self.assert_(u.user_name =='jack')
-        self.assert_(len(u.addresses) == 1)
-
-        # we're changing the database here, so if this test fails in the middle,
-        # it'll screw up the other tests which are hardcoded to 7/'jack'
-        u.user_name = 'foo'
-        s.flush()
-        # change the value in the DB
-        users.update(users.c.user_id==7, values=dict(user_name='jack')).execute()
-        s.expire(u)
-        # object isnt refreshed yet, using dict to bypass trigger
-        self.assert_(u.__dict__.get('user_name') != 'jack')
-        # do a select
-        s.query(User).select()
-        # test that it refreshed
-        self.assert_(u.__dict__['user_name'] == 'jack')
-
-        # object should be back to normal now,
-        # this should *not* produce a SELECT statement (not tested here though....)
-        self.assert_(u.user_name =='jack')
-
-    @testing.fails_on('maxdb')
-    def test_refresh2(self):
-        """test a hang condition that was occuring on expire/refresh"""
-
-        s = create_session()
-        m1 = mapper(Address, addresses)
-
-        m2 = mapper(User, users, properties = dict(addresses=relation(Address,private=True,lazy=False)) )
-        u=User()
-        u.user_name='Justin'
-        a = Address()
-        a.address_id=17  # to work around the hardcoded IDs in this test suite....
-        u.addresses.append(a)
-        s.flush()
-        s.clear()
-        u = s.query(User).selectfirst()
-        print u.user_name
-
-        #ok so far
-        s.expire(u)        #hangs when
-        print u.user_name #this line runs
-
-        s.refresh(u) #hangs
-
     def test_props(self):
         m = mapper(User, users, properties = {
             'addresses' : relation(mapper(Address, addresses))
@@ -299,7 +179,18 @@ class MapperTest(MapperSuperTest):
         sess.save(u3)
         sess.flush()
         sess.rollback()
-
+    
+    def test_illegal_non_primary(self):
+        mapper(User, users)
+        mapper(Address, addresses)
+        try:
+            mapper(User, users, non_primary=True, properties={
+                'addresses':relation(Address)
+            }).compile()
+            assert False
+        except exceptions.ArgumentError, e:
+            assert "Attempting to assign a new relation 'addresses' to a non-primary mapper on class 'User'" in str(e)
+        
     def test_propfilters(self):
         t = Table('person', MetaData(),
                   Column('id', Integer, primary_key=True),
index 4f19c2c32a8b96d4490c8f4034244c6c2e7f641d..438bc9634f0b13996256075722c6ebf31831c352 100644 (file)
@@ -12,16 +12,11 @@ from testlib.fixtures import *
 class QueryTest(FixtureTest):
     keep_mappers = True
     keep_data = True
-
+    
     def setUpAll(self):
         super(QueryTest, self).setUpAll()
-        install_fixture_data()
         self.setup_mappers()
 
-    def tearDownAll(self):
-        clear_mappers()
-        super(QueryTest, self).tearDownAll()
-
     def setup_mappers(self):
         mapper(User, users, properties={
             'addresses':relation(Address, backref='user'),
index b985cc8a50f8b1ed19b4383edf219cae005285f6..28efbf056c62cc7a31fa77879fbf43c19f69bb1b 100644 (file)
@@ -119,8 +119,10 @@ class VersioningTest(ORMTest):
         except exceptions.ConcurrentModificationError, e:
             assert True
         # reload it
+        print "RELOAD"
         s1.query(Foo).load(f1s1.id)
         # now assert version OK
+        print "VERSIONCHECK"
         s1.query(Foo).with_lockmode('read').get(f1s1.id)
         
         # assert brand new load is OK too
index 9105a29378f1a5225c15f0b467eb699fc8272db2..022fce094349807b406147592d4ba2231eaa6788 100644 (file)
@@ -193,6 +193,17 @@ def install_fixture_data():
     )
 
 class FixtureTest(ORMTest):
+    refresh_data = False
+    
+    def setUpAll(self):
+        super(FixtureTest, self).setUpAll()
+        if self.keep_data:
+            install_fixture_data()
+    
+    def setUp(self):
+        if self.refresh_data:
+            install_fixture_data()
+            
     def define_tables(self, meta):
         pass
 FixtureTest.metadata = metadata