From: Mike Bayer Date: Sun, 18 Nov 2007 02:13:56 +0000 (+0000) Subject: - session.refresh() and session.expire() now support an additional argument X-Git-Tag: rel_0_4_1~20 X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=622a26a6551a3580c844b634519ab963c7f35aaf;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - session.refresh() and session.expire() now support an additional argument "attribute_names", a list of individual attribute keynames to be refreshed or expired, allowing partial reloads of attributes on an already-loaded instance. - finally simplified the behavior of deferred attributes, deferred polymorphic load, session.refresh, session.expire, mapper._postfetch to all use a single codepath through query._get(), which now supports a list of individual attribute names to be refreshed. the *one* exception still remaining is mapper._get_poly_select_loader(), which may stay that way since its inline with an already processing load operation. otherwise, query._get() is the single place that all "load this instance's row" operation proceeds. - cleanup all over the place --- diff --git a/CHANGES b/CHANGES index d2a82953eb..7fc24cfd66 100644 --- a/CHANGES +++ b/CHANGES @@ -76,6 +76,11 @@ CHANGES columns to the result set. This eliminates a JOIN from all eager loads with LIMIT/OFFSET. [ticket:843] + - session.refresh() and session.expire() now support an additional argument + "attribute_names", a list of individual attribute keynames to be refreshed + or expired, allowing partial reloads of attributes on an already-loaded + instance. + - Mapped classes may now define __eq__, __hash__, and __nonzero__ methods with arbitrary sementics. The orm now handles all mapped instances on an identity-only basis. (e.g. 'is' vs '==') [ticket:676] diff --git a/doc/build/content/session.txt b/doc/build/content/session.txt index 804769a943..7698a1b5cf 100644 --- a/doc/build/content/session.txt +++ b/doc/build/content/session.txt @@ -396,6 +396,18 @@ To assist with the Session's "sticky" behavior of instances which are present, i session.expire(obj1) session.expire(obj2) +`refresh()` and `expire()` also support being passed a list of individual attribute names in which to be refreshed. These names can reference any attribute, column-based or relation based: + + {python} + # immediately re-load the attributes 'hello', 'world' on obj1, obj2 + session.refresh(obj1, ['hello', 'world']) + session.refresh(obj2, ['hello', 'world']) + + # expire the attriibutes 'hello', 'world' objects obj1, obj2, attributes will be reloaded + # on the next access: + session.expire(obj1, ['hello', 'world']) + session.expire(obj2, ['hello', 'world']) + ## Cascades Mappers support the concept of configurable *cascade* behavior on `relation()`s. This behavior controls how the Session should treat the instances that have a parent-child relationship with another instance that is operated upon by the Session. Cascade is indicated as a comma-separated list of string keywords, with the possible values `all`, `delete`, `save-update`, `refresh-expire`, `merge`, `expunge`, and `delete-orphan`. diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 5e22bc7628..7c760f15a8 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -177,41 +177,16 @@ class AttributeImpl(object): if callable_ is None: self.initialize(state) else: - state.callables[self] = callable_ + state.callables[self.key] = callable_ def _get_callable(self, state): - if self in state.callables: - return state.callables[self] + if self.key in state.callables: + return state.callables[self.key] elif self.callable_ is not None: return self.callable_(state.obj()) else: return None - def reset(self, state): - """Remove any per-instance callable functions corresponding to - this ``InstrumentedAttribute``'s attribute from the given - object, and remove this ``InstrumentedAttribute``'s attribute - from the given object's dictionary. - """ - - try: - del state.callables[self] - except KeyError: - pass - self.clear(state) - - def clear(self, state): - """Remove this ``InstrumentedAttribute``'s attribute from the given object's dictionary. - - Subsequent calls to ``getattr(obj, key)`` will raise an - ``AttributeError`` by default. - """ - - try: - del state.dict[self.key] - except KeyError: - pass - def check_mutable_modified(self, state): return False @@ -232,11 +207,6 @@ class AttributeImpl(object): try: return state.dict[self.key] except KeyError: - # if an instance-wide "trigger" was set, call that - # and start again - if state.trigger: - state.call_trigger() - return self.get(state, passive=passive) callable_ = self._get_callable(state) if callable_ is not None: @@ -246,6 +216,8 @@ class AttributeImpl(object): if value is not ATTR_WAS_SET: return self.set_committed_value(state, value) else: + if self.key not in state.dict: + return self.get(state, passive=passive) return state.dict[self.key] else: # Return a new, empty value @@ -278,10 +250,6 @@ class AttributeImpl(object): state.dict[self.key] = value return value - def set_raw_value(self, state, value): - state.dict[self.key] = value - return value - def fire_append_event(self, state, value, initiator): state.modified = True if self.trackparent and value is not None: @@ -318,7 +286,8 @@ class ScalarAttributeImpl(AttributeImpl): if copy_function is None: copy_function = self.__copy self.copy = copy_function - + self.accepts_global_callable = True + def __copy(self, item): # scalar values are assumed to be immutable unless a copy function # is passed @@ -350,10 +319,6 @@ class ScalarAttributeImpl(AttributeImpl): if initiator is self: return - # if an instance-wide "trigger" was set, call that - if state.trigger: - state.call_trigger() - state.dict[self.key] = value state.modified=True @@ -372,6 +337,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): compare_function=compare_function, mutable_scalars=mutable_scalars, **kwargs) if compare_function is None: self.is_equal = identity_equal + self.accepts_global_callable = False def delete(self, state): old = self.get(state) @@ -389,10 +355,6 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): if initiator is self: return - # if an instance-wide "trigger" was set, call that - if state.trigger: - state.call_trigger() - old = self.get(state) state.dict[self.key] = value self.fire_replace_event(state, value, old, initiator) @@ -416,6 +378,8 @@ class CollectionAttributeImpl(AttributeImpl): copy_function = self.__copy self.copy = copy_function + self.accepts_global_callable = False + if typecallable is None: typecallable = list self.collection_factory = \ @@ -478,10 +442,6 @@ class CollectionAttributeImpl(AttributeImpl): elif setting_type == dict: value = value.values() - # if an instance-wide "trigger" was set, call that - if state.trigger: - state.call_trigger() - old = self.get(state) old_collection = self.get_collection(state, old) @@ -583,13 +543,13 @@ class GenericBackrefExtension(interfaces.AttributeExtension): class InstanceState(object): """tracks state information at the instance level.""" - __slots__ = 'class_', 'obj', 'dict', 'committed_state', 'modified', 'trigger', 'callables', 'parents', 'instance_dict', '_strong_obj' + __slots__ = 'class_', 'obj', 'dict', 'committed_state', 'modified', 'trigger', 'callables', 'parents', 'instance_dict', '_strong_obj', 'expired_attributes' def __init__(self, obj): self.class_ = obj.__class__ self.obj = weakref.ref(obj, self.__cleanup) self.dict = obj.__dict__ - self.committed_state = None + self.committed_state = {} self.modified = False self.trigger = None self.callables = {} @@ -649,36 +609,75 @@ class InstanceState(object): self.dict = self.obj().__dict__ self.callables = {} self.trigger = None + + def initialize(self, key): + getattr(self.class_, key).impl.initialize(self) - def call_trigger(self): - trig = self.trigger - self.trigger = None - trig() + def set_callable(self, key, callable_): + self.dict.pop(key, None) + self.callables[key] = callable_ + + def __fire_trigger(self): + self.trigger(self.obj(), self.expired_attributes) + for k in self.expired_attributes: + self.callables.pop(k, None) + self.expired_attributes.clear() + return ATTR_WAS_SET + + def expire_attributes(self, attribute_names): + if not hasattr(self, 'expired_attributes'): + self.expired_attributes = util.Set() + if attribute_names is None: + for attr in self.class_._sa_attribute_manager.managed_attributes(self.class_): + self.dict.pop(attr.impl.key, None) + self.callables[attr.impl.key] = self.__fire_trigger + self.expired_attributes.add(attr.impl.key) + else: + for key in attribute_names: + self.dict.pop(key, None) + + if not getattr(self.class_, key).impl.accepts_global_callable: + continue + + self.callables[key] = self.__fire_trigger + self.expired_attributes.add(key) + + def reset(self, key): + """remove the given attribute and any callables associated with it.""" + + self.dict.pop(key, None) + self.callables.pop(key, None) + + def clear(self): + """clear all attributes from the instance.""" + + for attr in self.class_._sa_attribute_manager.managed_attributes(self.class_): + self.dict.pop(attr.impl.key, None) + + def commit(self, keys): + """commit all attributes named in the given list of key names. + + This is used by a partial-attribute load operation to mark committed those attributes + which were refreshed from the database. + """ + + for key in keys: + getattr(self.class_, key).impl.commit_to_state(self) + + def commit_all(self): + """commit all attributes unconditionally. + + This is used after a flush() or a regular instance load or refresh operation + to mark committed all populated attributes. + """ - def commit(self, manager, obj): self.committed_state = {} self.modified = False - for attr in manager.managed_attributes(obj.__class__): + for attr in self.class_._sa_attribute_manager.managed_attributes(self.class_): attr.impl.commit_to_state(self) # remove strong ref self._strong_obj = None - def rollback(self, manager, obj): - if not self.committed_state: - manager._clear(obj) - else: - for attr in manager.managed_attributes(obj.__class__): - if attr.impl.key in self.committed_state: - if not hasattr(attr.impl, 'get_collection'): - obj.__dict__[attr.impl.key] = self.committed_state[attr.impl.key] - else: - collection = attr.impl.get_collection(self) - collection.clear_without_event() - for item in self.committed_state[attr.impl.key]: - collection.append_without_event(item) - else: - if attr.impl.key in self.dict: - del self.dict[attr.impl.key] class InstanceDict(UserDict.UserDict): """similar to WeakValueDictionary, but wired towards 'state' objects.""" @@ -878,28 +877,6 @@ class AttributeManager(object): def clear_attribute_cache(self): self._attribute_cache.clear() - def rollback(self, *obj): - """Retrieve the committed history for each object in the given - list, and rolls back the attributes each instance to their - original value. - """ - - for o in obj: - o._state.rollback(self, o) - - def _clear(self, obj): - for attr in self.managed_attributes(obj.__class__): - try: - del obj.__dict__[attr.impl.key] - except KeyError: - pass - - def commit(self, *obj): - """Establish the "committed state" for each object in the given list.""" - - for o in obj: - o._state.commit(self, o) - def managed_attributes(self, class_): """Return a list of all ``InstrumentedAttribute`` objects associated with the given class. @@ -970,60 +947,9 @@ class AttributeManager(object): else: return [x] - def trigger_history(self, obj, callable): - """Clear all managed object attributes and places the given - `callable` as an attribute-wide *trigger*, which will execute - upon the next attribute access, after which the trigger is - removed. - """ - - s = obj._state - self._clear(obj) - s.committed_state = None - s.trigger = callable - - def untrigger_history(self, obj): - """Remove a trigger function set by trigger_history. - - Does not restore the previous state of the object. - """ - - obj._state.trigger = None - - def has_trigger(self, obj): - """Return True if the given object has a trigger function set - by ``trigger_history()``. - """ - - return obj._state.trigger is not None - - def reset_instance_attribute(self, obj, key): - """Remove any per-instance callable functions corresponding to - given attribute `key` from the given object, and remove this - attribute from the given object's dictionary. - """ - - attr = getattr(obj.__class__, key) - attr.impl.reset(obj._state) - - def is_class_managed(self, class_, key): - """Return True if the given `key` correponds to an - instrumented property on the given class. - """ - return hasattr(class_, key) and isinstance(getattr(class_, key), InstrumentedAttribute) - def has_parent(self, class_, obj, key, optimistic=False): return getattr(class_, key).impl.hasparent(obj._state, optimistic=optimistic) - def init_instance_attribute(self, obj, key, callable_=None, clear=False): - """Initialize an attribute on an instance to either a blank - value, cancelling out any class- or instance-level callables - that were present, or if a `callable` is supplied set the - callable to be invoked when the attribute is next accessed. - """ - - getattr(obj.__class__, key).impl.set_callable(obj._state, callable_, clear=clear) - def _create_prop(self, class_, key, uselist, callable_, typecallable, useobject, **kwargs): """Create a scalar property object, defaulting to ``InstrumentedAttribute``, which will communicate change @@ -1135,12 +1061,6 @@ class AttributeManager(object): setattr(class_, key, InstrumentedAttribute(self._create_prop(class_, key, uselist, callable_, useobject=useobject, typecallable=typecallable, **kwargs), comparator=comparator)) - def set_raw_value(self, instance, key, value): - getattr(instance.__class__, key).impl.set_raw_value(instance._state, value) - - def set_committed_value(self, instance, key, value): - getattr(instance.__class__, key).impl.set_committed_value(instance._state, value) - def init_collection(self, instance, key): """Initialize a collection attribute and return the collection adapter.""" attr = getattr(instance.__class__, key).impl diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index f523d0706b..a4a744c473 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1178,7 +1178,7 @@ class Mapper(object): prop = self._getpropbycolumn(c, raiseerror=False) if prop is None: continue - deferred_props.append(prop) + deferred_props.append(prop.key) continue if c.primary_key or not c.key in params: continue @@ -1189,7 +1189,7 @@ class Mapper(object): self.set_attr_by_column(obj, c, params[c.key]) if deferred_props: - deferred_load(obj, props=deferred_props) + expire_instance(obj, deferred_props) def delete_obj(self, objects, uowtransaction): """Issue ``DELETE`` statements for a list of objects. @@ -1342,7 +1342,7 @@ class Mapper(object): return self.__surrogate_mapper or self - def _instance(self, context, row, result = None, skip_polymorphic=False): + def _instance(self, context, row, result=None, skip_polymorphic=False, extension=None, only_load_props=None, refresh_instance=None): """Pull an object instance from the given row and append it to the given result list. @@ -1351,36 +1351,35 @@ class Mapper(object): on the instance to also process extra information in the row. """ - # apply ExtensionOptions applied to the Query to this mapper, - # but only if our mapper matches. - # TODO: what if our mapper inherits from the mapper (i.e. as in a polymorphic load?) - if context.mapper is self: - extension = context.extension - else: + if not extension: extension = self.extension - + if 'translate_row' in extension.methods: ret = extension.translate_row(self, context, row) if ret is not EXT_CONTINUE: row = ret - if not skip_polymorphic and self.polymorphic_on is not None: - discriminator = row[self.polymorphic_on] - if discriminator is not None: - mapper = self.polymorphic_map[discriminator] - if mapper is not self: - if ('polymorphic_fetch', mapper) not in context.attributes: - context.attributes[('polymorphic_fetch', mapper)] = (self, [t for t in mapper.tables if t not in self.tables]) - row = self.translate_row(mapper, row) - return mapper._instance(context, row, result=result, skip_polymorphic=True) + if refresh_instance is None: + if not skip_polymorphic and self.polymorphic_on is not None: + discriminator = row[self.polymorphic_on] + if discriminator is not None: + mapper = self.polymorphic_map[discriminator] + if mapper is not self: + if ('polymorphic_fetch', mapper) not in context.attributes: + context.attributes[('polymorphic_fetch', mapper)] = (self, [t for t in mapper.tables if t not in self.tables]) + row = self.translate_row(mapper, row) + return mapper._instance(context, row, result=result, skip_polymorphic=True) - # look in main identity map. if its there, we dont do anything to it, - # including modifying any of its related items lists, as its already - # been exposed to being modified by the application. - identitykey = self.identity_key_from_row(row) + # determine identity key + if refresh_instance: + identitykey = refresh_instance._instance_key + else: + identitykey = self.identity_key_from_row(row) (session_identity_map, local_identity_map) = (context.session.identity_map, context.identity_map) - + + # look in main identity map. if present, we only populate + # if repopulate flags are set. this block returns the instance. if identitykey in session_identity_map: instance = session_identity_map[identitykey] @@ -1397,19 +1396,21 @@ class Mapper(object): if identitykey not in local_identity_map: local_identity_map[identitykey] = instance isnew = True - if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: - self.populate_instance(context, instance, row, instancekey=identitykey, isnew=isnew) + if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, instancekey=identitykey, isnew=isnew, only_load_props=only_load_props) is EXT_CONTINUE: + self.populate_instance(context, instance, row, instancekey=identitykey, isnew=isnew, only_load_props=only_load_props) if 'append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: if result is not None: result.append(instance) + return instance - else: - if self.__should_log_debug: - self.__log_debug("_instance(): identity key %s not in session" % str(identitykey)) + + elif self.__should_log_debug: + self.__log_debug("_instance(): identity key %s not in session" % str(identitykey)) - # look in result-local identitymap for it. + # look in identity map which is local to this load operation if identitykey not in local_identity_map: + # check that sufficient primary key columns are present if self.allow_null_pks: # check if *all* primary key cols in the result are None - this indicates # an instance of the object is not present in the row. @@ -1424,7 +1425,6 @@ class Mapper(object): if None in identitykey[1]: return None - # plugin point if 'create_instance' in extension.methods: instance = extension.create_instance(self, context, row, self.class_) if instance is EXT_CONTINUE: @@ -1433,24 +1433,26 @@ class Mapper(object): instance = attribute_manager.new_instance(self.class_) instance._entity_name = self.entity_name + instance._instance_key = identitykey + if self.__should_log_debug: self.__log_debug("_instance(): created new instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey))) + local_identity_map[identitykey] = instance isnew = True else: + # instance is already present instance = local_identity_map[identitykey] isnew = False - # call further mapper properties on the row, to pull further - # instances from the row and possibly populate this item. + # populate. note that we still call this for an instance already loaded as additional collection state is present + # in subsequent rows (i.e. eagerly loaded collections) flags = {'instancekey':identitykey, 'isnew':isnew} - if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, **flags) is EXT_CONTINUE: - self.populate_instance(context, instance, row, **flags) + if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, **flags) is EXT_CONTINUE: + self.populate_instance(context, instance, row, only_load_props=only_load_props, **flags) if 'append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, **flags) is EXT_CONTINUE: if result is not None: result.append(instance) - - instance._instance_key = identitykey return instance @@ -1487,13 +1489,14 @@ class Mapper(object): """ if tomapper in self._row_translators: + # row translators are cached based on target mapper return self._row_translators[tomapper](row) else: translator = create_row_adapter(self.mapped_table, tomapper.mapped_table, equivalent_columns=self._equivalent_columns) self._row_translators[tomapper] = translator return translator(row) - def populate_instance(self, selectcontext, instance, row, ispostselect=None, isnew=False, **flags): + def populate_instance(self, selectcontext, instance, row, ispostselect=None, isnew=False, only_load_props=None, **flags): """populate an instance from a result row.""" snapshot = selectcontext.path + (self,) @@ -1511,6 +1514,8 @@ class Mapper(object): existing_populators = [] post_processors = [] for prop in self.__props.values(): + if only_load_props and prop.key not in only_load_props: + continue (newpop, existingpop, post_proc) = selectcontext.exec_with_path(self, prop.key, prop.create_row_processor, selectcontext, self, row) if newpop is not None: new_populators.append((prop.key, newpop)) @@ -1518,7 +1523,8 @@ class Mapper(object): existing_populators.append((prop.key, existingpop)) if post_proc is not None: post_processors.append(post_proc) - + + # install a post processor for immediate post-load of joined-table inheriting mappers poly_select_loader = self._get_poly_select_loader(selectcontext, row) if poly_select_loader is not None: post_processors.append(poly_select_loader) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index f2a60315c1..df52512eed 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -599,7 +599,7 @@ class PropertyLoader(StrategizedProperty): if self.backref is not None: self.backref.compile(self) - elif not sessionlib.attribute_manager.is_class_managed(self.parent.class_, self.key): + elif not mapper.class_mapper(self.parent.class_).get_property(self.key, raiseerr=False): raise exceptions.ArgumentError("Attempting to assign a new relation '%s' to a non-primary mapper on class '%s'. New relations can only be added to the primary mapper, i.e. the very first mapper created for class '%s' " % (self.key, self.parent.class_.__name__, self.parent.class_.__name__)) super(PropertyLoader, self).do_init() @@ -702,28 +702,5 @@ class BackRef(object): return attributes.GenericBackrefExtension(self.key) -def deferred_load(instance, props): - """set multiple instance attributes to 'deferred' or 'lazy' load, for the given set of MapperProperty objects. - - this will remove the current value of the attribute and set a per-instance - callable to fire off when the instance is next accessed. - - for column-based properties, aggreagtes them into a single list against a single deferred loader - so that a single column access loads all columns - - """ - - if not props: - return - column_props = [p for p in props if isinstance(p, ColumnProperty)] - callable_ = column_props[0]._get_strategy(strategies.DeferredColumnLoader).setup_loader(instance, props=column_props) - for p in column_props: - sessionlib.attribute_manager.init_instance_attribute(instance, p.key, callable_=callable_, clear=True) - - for p in [p for p in props if isinstance(p, PropertyLoader)]: - callable_ = p._get_strategy(strategies.LazyLoader).setup_loader(instance) - sessionlib.attribute_manager.init_instance_attribute(instance, p.key, callable_=callable_, clear=True) - mapper.ColumnProperty = ColumnProperty -mapper.deferred_load = deferred_load diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index a4460ea254..b828750d4a 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -50,6 +50,8 @@ class Query(object): self._attributes = {} self._current_path = () self._primary_adapter=None + self._only_load_props = None + self._refresh_instance = None def _clone(self): q = Query.__new__(Query) @@ -103,12 +105,14 @@ class Query(object): key column values in the order of the table def's primary key columns. """ - + print "LOAD CHECK1" ret = self._extension.load(self, ident, **kwargs) if ret is not mapper.EXT_CONTINUE: return ret + print "LOAD CHECK2" key = self.mapper.identity_key_from_primary_key(ident) - instance = self._get(key, ident, reload=True, **kwargs) + instance = self.populate_existing()._get(key, ident, **kwargs) + print "LOAD CHECK3" if instance is None and raiseerr: raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident)) return instance @@ -655,7 +659,7 @@ class Query(object): return self._execute_and_instances(context) def _execute_and_instances(self, querycontext): - result = self.session.execute(querycontext.statement, params=self._params, mapper=self.mapper) + result = self.session.execute(querycontext.statement, params=self._params, mapper=self.mapper, instance=self._refresh_instance) try: return iter(self.instances(result, querycontext=querycontext)) finally: @@ -670,8 +674,6 @@ class Query(object): and add_column(). """ - self.__log_debug("instances()") - session = self.session context = kwargs.pop('querycontext', None) @@ -713,54 +715,68 @@ class Query(object): result = [] else: result = util.UniqueAppender([]) - + + primary_mapper_args = dict(extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance) + for row in cursor.fetchall(): if self._primary_adapter: - self.select_mapper._instance(context, self._primary_adapter(row), result) + self.select_mapper._instance(context, self._primary_adapter(row), result, **primary_mapper_args) else: - self.select_mapper._instance(context, row, result) + self.select_mapper._instance(context, row, result, **primary_mapper_args) for proc in process: proc[0](context, row) for instance in context.identity_map.values(): context.attributes.get(('populating_mapper', id(instance)), object_mapper(instance))._post_instance(context, instance) - + + if context.refresh_instance and context.only_load_props and context.refresh_instance._instance_key in context.identity_map: + # if refreshing partial instance, do special state commit + # affecting only the refreshed attributes + context.refresh_instance._state.commit(context.only_load_props) + del context.identity_map[context.refresh_instance._instance_key] + # store new stuff in the identity map for instance in context.identity_map.values(): session._register_persistent(instance) - + if mappers_or_columns: return list(util.OrderedSet(zip(*([result] + [o[1] for o in process])))) else: return result.data - def _get(self, key, ident=None, reload=False, lockmode=None): + def _get(self, key=None, ident=None, refresh_instance=None, lockmode=None, only_load_props=None): lockmode = lockmode or self._lockmode - if not reload and not self.mapper.always_refresh and lockmode is None: + if not self._populate_existing and not refresh_instance and not self.mapper.always_refresh and lockmode is None: try: return self.session.identity_map[key] except KeyError: pass - + if ident is None: - ident = key[1] + if key is not None: + ident = key[1] else: ident = util.to_list(ident) - params = {} + + q = self - (_get_clause, _get_params) = self.select_mapper._get_clause - for i, primary_key in enumerate(self.primary_key_columns): - try: - params[_get_params[primary_key].key] = ident[i] - except IndexError: - raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in self.primary_key_columns])) + if ident is not None: + params = {} + (_get_clause, _get_params) = self.select_mapper._get_clause + q = q.filter(_get_clause) + for i, primary_key in enumerate(self.primary_key_columns): + try: + params[_get_params[primary_key].key] = ident[i] + except IndexError: + raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in self.primary_key_columns])) + q = q.params(params) + try: - q = self if lockmode is not None: q = q.with_lockmode(lockmode) - q = q.filter(_get_clause) - q = q.params(params)._select_context_options(populate_existing=reload, version_check=(lockmode is not None)) + q = q._select_context_options(populate_existing=refresh_instance is not None, version_check=(lockmode is not None), only_load_props=only_load_props, refresh_instance=refresh_instance) + q = q.order_by(None) # call using all() to avoid LIMIT compilation complexity return q.all()[0] except IndexError: @@ -861,7 +877,9 @@ class Query(object): # TODO: doing this off the select_mapper. if its the polymorphic mapper, then # it has no relations() on it. should we compile those too into the query ? (i.e. eagerloads) for value in self.select_mapper.iterate_properties: - context.exec_with_path(self.select_mapper, value.key, value.setup, context) + if self._only_load_props and value.key not in self._only_load_props: + continue + context.exec_with_path(self.select_mapper, value.key, value.setup, context, only_load_props=self._only_load_props) # additional entities/columns, add those to selection criterion for tup in self._entities: @@ -912,7 +930,6 @@ class Query(object): statement.append_order_by(*context.eager_order_by) else: statement = sql.select(context.primary_columns + context.secondary_columns, whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **self._select_args()) - if context.eager_joins: statement.append_from(context.eager_joins, _copy_collection=False) @@ -1101,11 +1118,15 @@ class Query(object): q._select_context_options(**kwargs) return list(q) - def _select_context_options(self, populate_existing=None, version_check=None): #pragma: no cover - if populate_existing is not None: + def _select_context_options(self, populate_existing=None, version_check=None, only_load_props=None, refresh_instance=None): #pragma: no cover + if populate_existing: self._populate_existing = populate_existing - if version_check is not None: + if version_check: self._version_check = version_check + if refresh_instance is not None: + self._refresh_instance = refresh_instance + if only_load_props: + self._only_load_props = util.Set(only_load_props) return self def join_to(self, key): #pragma: no cover @@ -1207,6 +1228,8 @@ class QueryContext(object): self.statement = None self.populate_existing = query._populate_existing self.version_check = query._version_check + self.only_load_props = query._only_load_props + self.refresh_instance = query._refresh_instance self.identity_map = {} self.path = () self.primary_columns = [] diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 0b57447788..7cdc7f6719 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -6,7 +6,7 @@ import weakref from sqlalchemy import util, exceptions, sql, engine -from sqlalchemy.orm import unitofwork, query, util as mapperutil +from sqlalchemy.orm import unitofwork, query, attributes, util as mapperutil from sqlalchemy.orm.mapper import object_mapper as _object_mapper from sqlalchemy.orm.mapper import class_mapper as _class_mapper from sqlalchemy.orm.mapper import Mapper @@ -522,9 +522,9 @@ class Session(object): resources of the underlying ``Connection``. """ - engine = self.get_bind(mapper, clause=clause) + engine = self.get_bind(mapper, clause=clause, **kwargs) - return self.__connection(engine, close_with_result=True).execute(clause, params or {}, **kwargs) + return self.__connection(engine, close_with_result=True).execute(clause, params or {}) def scalar(self, clause, params=None, mapper=None, **kwargs): """Like execute() but return a scalar result.""" @@ -716,26 +716,57 @@ class Session(object): entity_name = kwargs.pop('entity_name', None) return self.query(class_, entity_name=entity_name).load(ident, **kwargs) - def refresh(self, obj): - """Reload the attributes for the given object from the - database, clear any changes made. + def refresh(self, obj, attribute_names=None): + """Refresh the attributes on the given instance. + + When called, a query will be issued + to the database which will refresh all attributes with their + current value. + + Lazy-loaded relational attributes will remain lazily loaded, so that + the instance-wide refresh operation will be followed + immediately by the lazy load of that attribute. + + Eagerly-loaded relational attributes will eagerly load within the + single refresh operation. + + The ``attribute_names`` argument is an iterable collection + of attribute names indicating a subset of attributes to be + refreshed. """ self._validate_persistent(obj) - if self.query(obj.__class__)._get(obj._instance_key, reload=True) is None: + + if self.query(obj.__class__)._get(obj._instance_key, refresh_instance=obj, only_load_props=attribute_names) is None: raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(obj)) - def expire(self, obj): - """Mark the given object as expired. - - This will add an instrumentation to all mapped attributes on - the instance such that when an attribute is next accessed, the - session will reload all attributes on the instance from the - database. + def expire(self, obj, attribute_names=None): + """Expire the attributes on the given instance. + + The instance's attributes are instrumented such that + when an attribute is next accessed, a query will be issued + to the database which will refresh all attributes with their + current value. + + Lazy-loaded relational attributes will remain lazily loaded, so that + triggering one will incur the instance-wide refresh operation, followed + immediately by the lazy load of that attribute. + + Eagerly-loaded relational attributes will eagerly load within the + single refresh operation. + + The ``attribute_names`` argument is an iterable collection + of attribute names indicating a subset of attributes to be + expired. """ - - for c in [obj] + list(_object_mapper(obj).cascade_iterator('refresh-expire', obj)): - self._expire_impl(c) + + if attribute_names: + self._validate_persistent(obj) + expire_instance(obj, attribute_names=attribute_names) + else: + for c in [obj] + list(_object_mapper(obj).cascade_iterator('refresh-expire', obj)): + self._validate_persistent(obj) + expire_instance(c, None) def prune(self): """Removes unreferenced instances cached in the identity map. @@ -750,21 +781,12 @@ class Session(object): return self.uow.prune_identity_map() - def _expire_impl(self, obj): - self._validate_persistent(obj) - - def exp(): - if self.query(obj.__class__)._get(obj._instance_key, reload=True) is None: - raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(obj)) - - attribute_manager.trigger_history(obj, exp) - def is_expired(self, obj, unexpire=False): """Return True if the given object has been marked as expired.""" - ret = attribute_manager.has_trigger(obj) + ret = obj._state.trigger is not None if ret and unexpire: - attribute_manager.untrigger_history(obj) + obj._state.trigger = None return ret def expunge(self, object): @@ -999,7 +1021,7 @@ class Session(object): def _register_persistent(self, obj): obj._sa_session_id = self.hash_key self.identity_map[obj._instance_key] = obj - attribute_manager.commit(obj) + obj._state.commit_all() def _attach(self, obj): old_id = getattr(obj, '_sa_session_id', None) @@ -1083,6 +1105,25 @@ class Session(object): new = property(lambda s:s.uow.new, doc="A ``Set`` of all objects marked as 'new' within this ``Session``.") +def expire_instance(obj, attribute_names): + """standalone expire instance function. + + installs a callable with the given instance's _state + which will fire off when any of the named attributes are accessed; + their existing value is removed. + + If the list is None or blank, the entire instance is expired. + """ + + if obj._state.trigger is None: + def load_attributes(instance, attribute_names): + if object_session(instance).query(instance.__class__)._get(instance._instance_key, refresh_instance=instance, only_load_props=attribute_names) is None: + raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance)) + obj._state.trigger = load_attributes + + obj._state.expire_attributes(attribute_names) + + # this is the AttributeManager instance used to provide attribute behavior on objects. # to all the "global variable police" out there: its a stateless object. @@ -1108,3 +1149,4 @@ def object_session(obj): unitofwork.object_session = object_session from sqlalchemy.orm import mapper mapper.attribute_manager = attribute_manager +mapper.expire_instance = expire_instance \ No newline at end of file diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 8574d2fef2..dfd1efa36e 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -114,8 +114,7 @@ class ColumnLoader(LoaderStrategy): def new_execute(instance, row, isnew, **flags): if isnew: - loader = strategy.setup_loader(instance, props=props, create_statement=create_statement) - sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=loader) + instance._state.set_callable(self.key, strategy.setup_loader(instance, props=props, create_statement=create_statement)) if self._should_log_debug: self.logger.debug("Returning deferred column fetcher for %s %s" % (mapper, self.key)) @@ -134,19 +133,19 @@ class DeferredColumnLoader(LoaderStrategy): """Deferred column loader, a per-column or per-column-group lazy loader.""" def create_row_processor(self, selectcontext, mapper, row): - if (self.group is not None and selectcontext.attributes.get(('undefer', self.group), False)) or self.columns[0] in row: + if self.columns[0] in row: return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, mapper, row) elif not self.is_class_level or len(selectcontext.options): def new_execute(instance, row, **flags): if self._should_log_debug: self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key)) - sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self.setup_loader(instance)) + instance._state.set_callable(self.key, self.setup_loader(instance)) return (new_execute, None, None) else: def new_execute(instance, row, **flags): if self._should_log_debug: self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key)) - sessionlib.attribute_manager.reset_instance_attribute(instance, self.key) + instance._state.reset(self.key) return (new_execute, None, None) def init(self): @@ -162,8 +161,11 @@ class DeferredColumnLoader(LoaderStrategy): self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__)) sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, callable_=self.setup_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator) - def setup_query(self, context, **kwargs): - if self.group is not None and context.attributes.get(('undefer', self.group), False): + def setup_query(self, context, only_load_props=None, **kwargs): + if \ + (self.group is not None and context.attributes.get(('undefer', self.group), False)) or \ + (only_load_props and self.key in only_load_props): + self.parent_property._get_strategy(ColumnLoader).setup_query(context, **kwargs) def setup_loader(self, instance, props=None, create_statement=None): @@ -198,27 +200,12 @@ class DeferredColumnLoader(LoaderStrategy): raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key)) if create_statement is None: - (clause, param_map) = localparent._get_clause ident = instance._instance_key[1] - params = {} - for i, primary_key in enumerate(localparent.primary_key): - params[param_map[primary_key].key] = ident[i] - statement = sql.select([p.columns[0] for p in group], clause, from_obj=[localparent.mapped_table], use_labels=True) + session.query(localparent)._get(None, ident=ident, only_load_props=[p.key for p in group], refresh_instance=instance) else: statement, params = create_statement(instance) - - # TODO: have the "fetch of one row" operation go through the same channels as a query._get() - # deferred load of several attributes should be a specialized case of a query refresh operation - conn = session.connection(mapper=localparent, instance=instance) - result = conn.execute(statement, params) - try: - row = result.fetchone() - for prop in group: - sessionlib.attribute_manager.set_committed_value(instance, prop.key, row[prop.columns[0]]) - return attributes.ATTR_WAS_SET - finally: - result.close() - + session.query(localparent).from_statement(statement).params(params)._get(None, only_load_props=[p.key for p in group], refresh_instance=instance) + return attributes.ATTR_WAS_SET return lazyload DeferredColumnLoader.logger = logging.class_logger(DeferredColumnLoader) @@ -248,7 +235,10 @@ class AbstractRelationLoader(LoaderStrategy): self._should_log_debug = logging.is_debug_enabled(self.logger) def _init_instance_attribute(self, instance, callable_=None): - return sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=callable_) + if callable_: + instance._state.set_callable(self.key, callable_) + else: + instance._state.initialize(self.key) def _register_attribute(self, class_, callable_=None, **kwargs): self.logger.info("register managed %s attribute %s on class %s" % ((self.uselist and "list-holding" or "scalar"), self.key, self.parent.class_.__name__)) @@ -399,7 +389,7 @@ class LazyLoader(AbstractRelationLoader): # so that the class-level lazy loader is executed when next referenced on this instance. # this usually is not needed unless the constructor of the object referenced the attribute before we got # to load data into it. - sessionlib.attribute_manager.reset_instance_attribute(instance, self.key) + instance._state.reset(self.key) return (new_execute, None, None) def _create_lazy_clause(cls, prop, reverse_direction=False): @@ -603,10 +593,7 @@ class EagerLoader(AbstractRelationLoader): # parent object, bypassing InstrumentedAttribute # event handlers. # - # FIXME: instead of... - sessionlib.attribute_manager.set_raw_value(instance, self.key, self.select_mapper._instance(selectcontext, decorated_row, None)) - # bypass and set directly: - #instance.__dict__[self.key] = self.select_mapper._instance(selectcontext, decorated_row, None) + instance.__dict__[self.key] = self.select_mapper._instance(selectcontext, decorated_row, None) else: # call _instance on the row, even though the object has been created, # so that we further descend into properties diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 7f9a4d7d06..2cd7cb6f5d 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -128,7 +128,7 @@ class UnitOfWork(object): if hasattr(obj, '_sa_insert_order'): delattr(obj, '_sa_insert_order') self.identity_map[obj._instance_key] = obj - attribute_manager.commit(obj) + obj._state.commit_all() def register_new(self, obj): """register the given object as 'new' (i.e. unsaved) within this unit of work.""" diff --git a/test/orm/alltests.py b/test/orm/alltests.py index 59357c7b71..059d7a100e 100644 --- a/test/orm/alltests.py +++ b/test/orm/alltests.py @@ -11,6 +11,7 @@ def suite(): 'orm.lazy_relations', 'orm.eager_relations', 'orm.mapper', + 'orm.expire', 'orm.selectable', 'orm.collection', 'orm.generative', diff --git a/test/orm/attributes.py b/test/orm/attributes.py index 930bfa57e2..2080474edd 100644 --- a/test/orm/attributes.py +++ b/test/orm/attributes.py @@ -5,6 +5,8 @@ from sqlalchemy.orm.collections import collection from sqlalchemy import exceptions from testlib import * +ROLLBACK_SUPPORTED=False + # these test classes defined at the module # level to support pickling class MyTest(object):pass @@ -29,7 +31,7 @@ class AttributesTest(PersistTest): print repr(u.__dict__) self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') - manager.commit(u) + u._state.commit_all() print repr(u.__dict__) self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') @@ -37,10 +39,11 @@ class AttributesTest(PersistTest): u.email_address = 'foo@bar.com' print repr(u.__dict__) self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.email_address == 'foo@bar.com') - - manager.rollback(u) - print repr(u.__dict__) - self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') + + if ROLLBACK_SUPPORTED: + manager.rollback(u) + print repr(u.__dict__) + self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') def test_pickleness(self): @@ -128,7 +131,7 @@ class AttributesTest(PersistTest): print repr(u.__dict__) self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') - manager.commit(u, a) + u, a._state.commit_all() print repr(u.__dict__) self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') @@ -140,11 +143,12 @@ class AttributesTest(PersistTest): print repr(u.__dict__) self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.addresses[0].email_address == 'lala@123.com' and u.addresses[1].email_address == 'foo@bar.com') - manager.rollback(u, a) - print repr(u.__dict__) - print repr(u.addresses[0].__dict__) - self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') - self.assert_(len(manager.get_history(u, 'addresses').unchanged_items()) == 1) + if ROLLBACK_SUPPORTED: + manager.rollback(u, a) + print repr(u.__dict__) + print repr(u.addresses[0].__dict__) + self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') + self.assert_(len(manager.get_history(u, 'addresses').unchanged_items()) == 1) def test_backref(self): class Student(object):pass @@ -231,9 +235,9 @@ class AttributesTest(PersistTest): # create objects as if they'd been freshly loaded from the database (without history) b = Blog() p1 = Post() - manager.init_instance_attribute(b, 'posts', lambda:[p1]) - manager.init_instance_attribute(p1, 'blog', lambda:b) - manager.commit(p1, b) + b._state.set_callable('posts', lambda:[p1]) + p1._state.set_callable('blog', lambda:b) + p1, b._state.commit_all() # no orphans (called before the lazy loaders fire off) assert manager.has_parent(Blog, p1, 'posts', optimistic=True) @@ -292,7 +296,7 @@ class AttributesTest(PersistTest): x.element = 'this is the element' hist = manager.get_history(x, 'element') assert hist.added_items() == ['this is the element'] - manager.commit(x) + x._state.commit_all() hist = manager.get_history(x, 'element') assert hist.added_items() == [] assert hist.unchanged_items() == ['this is the element'] @@ -320,7 +324,7 @@ class AttributesTest(PersistTest): manager.register_attribute(Bar, 'id', uselist=False, useobject=True) x = Foo() - manager.commit(x) + x._state.commit_all() x.col2.append(Bar(4)) h = manager.get_history(x, 'col2') print h.added_items() @@ -362,7 +366,7 @@ class AttributesTest(PersistTest): manager.register_attribute(Foo, 'element', uselist=False, copy_function=lambda x:[y for y in x], mutable_scalars=True, useobject=False) x = Foo() x.element = ['one', 'two', 'three'] - manager.commit(x) + x._state.commit_all() x.element[1] = 'five' assert manager.is_modified(x) @@ -372,7 +376,7 @@ class AttributesTest(PersistTest): manager.register_attribute(Foo, 'element', uselist=False, useobject=False) x = Foo() x.element = ['one', 'two', 'three'] - manager.commit(x) + x._state.commit_all() x.element[1] = 'five' assert not manager.is_modified(x) diff --git a/test/orm/dynamic.py b/test/orm/dynamic.py index 80e3982da2..0bb253b198 100644 --- a/test/orm/dynamic.py +++ b/test/orm/dynamic.py @@ -7,12 +7,10 @@ from testlib.fixtures import * from query import QueryTest -class DynamicTest(QueryTest): +class DynamicTest(FixtureTest): keep_mappers = False - - def setup_mappers(self): - pass - + keep_data = True + def test_basic(self): mapper(User, users, properties={ 'addresses':dynamic_loader(mapper(Address, addresses)) diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py index 3e811b86bf..52602ecae3 100644 --- a/test/orm/eager_relations.py +++ b/test/orm/eager_relations.py @@ -7,9 +7,10 @@ from testlib import * from testlib.fixtures import * from query import QueryTest -class EagerTest(QueryTest): +class EagerTest(FixtureTest): keep_mappers = False - + keep_data = True + def setup_mappers(self): pass diff --git a/test/orm/expire.py b/test/orm/expire.py new file mode 100644 index 0000000000..430117725c --- /dev/null +++ b/test/orm/expire.py @@ -0,0 +1,439 @@ +"""test attribute/instance expiration, deferral of attributes, etc.""" + +import testbase +from sqlalchemy import * +from sqlalchemy import exceptions +from sqlalchemy.orm import * +from testlib import * +from testlib.fixtures import * + +class ExpireTest(FixtureTest): + keep_mappers = False + refresh_data = True + + def test_expire(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user'), + }) + mapper(Address, addresses) + + sess = create_session() + u = sess.query(User).get(7) + assert len(u.addresses) == 1 + u.name = 'foo' + del u.addresses[0] + sess.expire(u) + + assert 'name' not in u.__dict__ + + def go(): + assert u.name == 'jack' + self.assert_sql_count(testbase.db, go, 1) + assert 'name' in u.__dict__ + + # we're changing the database here, so if this test fails in the middle, + # it'll screw up the other tests which are hardcoded to 7/'jack' + u.name = 'foo' + sess.flush() + # change the value in the DB + users.update(users.c.id==7, values=dict(name='jack')).execute() + sess.expire(u) + # object isnt refreshed yet, using dict to bypass trigger + assert u.__dict__.get('name') != 'jack' + # reload all + sess.query(User).all() + # test that it refreshed + assert u.__dict__['name'] == 'jack' + + # object should be back to normal now, + # this should *not* produce a SELECT statement (not tested here though....) + assert u.name == 'jack' + + def test_expire_doesntload_on_set(self): + mapper(User, users) + + sess = create_session() + u = sess.query(User).get(7) + + sess.expire(u, attribute_names=['name']) + def go(): + u.name = 'somenewname' + self.assert_sql_count(testbase.db, go, 0) + sess.flush() + sess.clear() + assert sess.query(User).get(7).name == 'somenewname' + + def test_expire_committed(self): + """test that the committed state of the attribute receives the most recent DB data""" + mapper(Order, orders) + + sess = create_session() + o = sess.query(Order).get(3) + sess.expire(o) + + orders.update(id=3).execute(description='order 3 modified') + assert o.isopen == 1 + assert o._state.committed_state['description'] == 'order 3 modified' + def go(): + sess.flush() + self.assert_sql_count(testbase.db, go, 0) + + def test_expire_cascade(self): + mapper(User, users, properties={ + 'addresses':relation(Address, cascade="all, refresh-expire") + }) + mapper(Address, addresses) + s = create_session() + u = s.get(User, 8) + assert u.addresses[0].email_address == 'ed@wood.com' + + u.addresses[0].email_address = 'someotheraddress' + s.expire(u) + u.name + print u._state.dict + assert u.addresses[0].email_address == 'ed@wood.com' + + def test_expired_lazy(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user'), + }) + mapper(Address, addresses) + + sess = create_session() + u = sess.query(User).get(7) + + sess.expire(u) + assert 'name' not in u.__dict__ + assert 'addresses' not in u.__dict__ + + def go(): + assert u.addresses[0].email_address == 'jack@bean.com' + assert u.name == 'jack' + # two loads + self.assert_sql_count(testbase.db, go, 2) + assert 'name' in u.__dict__ + assert 'addresses' in u.__dict__ + + def test_expired_eager(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user', lazy=False), + }) + mapper(Address, addresses) + + sess = create_session() + u = sess.query(User).get(7) + + sess.expire(u) + assert 'name' not in u.__dict__ + assert 'addresses' not in u.__dict__ + + def go(): + assert u.addresses[0].email_address == 'jack@bean.com' + assert u.name == 'jack' + # one load + self.assert_sql_count(testbase.db, go, 1) + assert 'name' in u.__dict__ + assert 'addresses' in u.__dict__ + + def test_partial_expire(self): + mapper(Order, orders) + + sess = create_session() + o = sess.query(Order).get(3) + + sess.expire(o, attribute_names=['description']) + assert 'id' in o.__dict__ + assert 'description' not in o.__dict__ + assert o._state.committed_state['isopen'] == 1 + + orders.update(orders.c.id==3).execute(description='order 3 modified') + + def go(): + assert o.description == 'order 3 modified' + self.assert_sql_count(testbase.db, go, 1) + assert o._state.committed_state['description'] == 'order 3 modified' + + o.isopen = 5 + sess.expire(o, attribute_names=['description']) + assert 'id' in o.__dict__ + assert 'description' not in o.__dict__ + assert o.__dict__['isopen'] == 5 + assert o._state.committed_state['isopen'] == 1 + + def go(): + assert o.description == 'order 3 modified' + self.assert_sql_count(testbase.db, go, 1) + assert o.__dict__['isopen'] == 5 + assert o._state.committed_state['description'] == 'order 3 modified' + assert o._state.committed_state['isopen'] == 1 + + sess.flush() + + sess.expire(o, attribute_names=['id', 'isopen', 'description']) + assert 'id' not in o.__dict__ + assert 'isopen' not in o.__dict__ + assert 'description' not in o.__dict__ + def go(): + assert o.description == 'order 3 modified' + assert o.id == 3 + assert o.isopen == 5 + self.assert_sql_count(testbase.db, go, 1) + + def test_partial_expire_lazy(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user'), + }) + mapper(Address, addresses) + + sess = create_session() + u = sess.query(User).get(8) + + sess.expire(u, ['name', 'addresses']) + assert 'name' not in u.__dict__ + assert 'addresses' not in u.__dict__ + + # hit the lazy loader. just does the lazy load, + # doesnt do the overall refresh + def go(): + assert u.addresses[0].email_address=='ed@wood.com' + self.assert_sql_count(testbase.db, go, 1) + + assert 'name' not in u.__dict__ + + # check that mods to expired lazy-load attributes + # only do the lazy load + sess.expire(u, ['name', 'addresses']) + def go(): + u.addresses = [Address(id=10, email_address='foo@bar.com')] + self.assert_sql_count(testbase.db, go, 1) + + sess.flush() + + # flush has occurred, and addresses was modified, + # so the addresses collection got committed and is + # longer expired + def go(): + assert u.addresses[0].email_address=='foo@bar.com' + assert len(u.addresses) == 1 + self.assert_sql_count(testbase.db, go, 0) + + # but the name attribute was never loaded and so + # still loads + def go(): + assert u.name == 'ed' + self.assert_sql_count(testbase.db, go, 1) + + def test_partial_expire_eager(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user', lazy=False), + }) + mapper(Address, addresses) + + sess = create_session() + u = sess.query(User).get(8) + + sess.expire(u, ['name', 'addresses']) + assert 'name' not in u.__dict__ + assert 'addresses' not in u.__dict__ + + def go(): + assert u.addresses[0].email_address=='ed@wood.com' + self.assert_sql_count(testbase.db, go, 1) + + # check that mods to expired eager-load attributes + # do the refresh + sess.expire(u, ['name', 'addresses']) + def go(): + u.addresses = [Address(id=10, email_address='foo@bar.com')] + self.assert_sql_count(testbase.db, go, 1) + sess.flush() + + # this should ideally trigger the whole load + # but currently it works like the lazy case + def go(): + assert u.addresses[0].email_address=='foo@bar.com' + assert len(u.addresses) == 1 + self.assert_sql_count(testbase.db, go, 0) + + def go(): + assert u.name == 'ed' + # scalar attributes have their own load + self.assert_sql_count(testbase.db, go, 1) + # ideally, this was already loaded, but we arent + # doing it that way right now + #self.assert_sql_count(testbase.db, go, 0) + + def test_partial_expire_deferred(self): + mapper(Order, orders, properties={ + 'description':deferred(orders.c.description) + }) + + sess = create_session() + o = sess.query(Order).get(3) + sess.expire(o, ['description', 'isopen']) + assert 'isopen' not in o.__dict__ + assert 'description' not in o.__dict__ + + # test that expired attribute access refreshes + # the deferred + def go(): + assert o.isopen == 1 + assert o.description == 'order 3' + self.assert_sql_count(testbase.db, go, 1) + + sess.expire(o, ['description', 'isopen']) + assert 'isopen' not in o.__dict__ + assert 'description' not in o.__dict__ + # test that the deferred attribute triggers the full + # reload + def go(): + assert o.description == 'order 3' + assert o.isopen == 1 + self.assert_sql_count(testbase.db, go, 1) + + clear_mappers() + + mapper(Order, orders) + sess.clear() + + # same tests, using deferred at the options level + o = sess.query(Order).options(defer('description')).get(3) + + assert 'description' not in o.__dict__ + + # sanity check + def go(): + assert o.description == 'order 3' + self.assert_sql_count(testbase.db, go, 1) + + assert 'description' in o.__dict__ + assert 'isopen' in o.__dict__ + sess.expire(o, ['description', 'isopen']) + assert 'isopen' not in o.__dict__ + assert 'description' not in o.__dict__ + + # test that expired attribute access refreshes + # the deferred + def go(): + assert o.isopen == 1 + assert o.description == 'order 3' + self.assert_sql_count(testbase.db, go, 1) + sess.expire(o, ['description', 'isopen']) + + assert 'isopen' not in o.__dict__ + assert 'description' not in o.__dict__ + # test that the deferred attribute triggers the full + # reload + def go(): + assert o.description == 'order 3' + assert o.isopen == 1 + self.assert_sql_count(testbase.db, go, 1) + + +class RefreshTest(FixtureTest): + keep_mappers = False + refresh_data = True + + def test_refresh(self): + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), backref='user') + }) + s = create_session() + u = s.get(User, 7) + u.name = 'foo' + a = Address() + assert object_session(a) is None + u.addresses.append(a) + assert a.email_address is None + assert id(a) in [id(x) for x in u.addresses] + + s.refresh(u) + + # its refreshed, so not dirty + assert u not in s.dirty + + # username is back to the DB + assert u.name == 'jack' + + assert id(a) not in [id(x) for x in u.addresses] + + u.name = 'foo' + u.addresses.append(a) + # now its dirty + assert u in s.dirty + assert u.name == 'foo' + assert id(a) in [id(x) for x in u.addresses] + s.expire(u) + + # get the attribute, it refreshes + assert u.name == 'jack' + assert id(a) not in [id(x) for x in u.addresses] + + def test_refresh_expired(self): + mapper(User, users) + s = create_session() + u = s.get(User, 7) + s.expire(u) + assert 'name' not in u.__dict__ + s.refresh(u) + assert u.name == 'jack' + + def test_refresh_with_lazy(self): + """test that when a lazy loader is set as a trigger on an object's attribute + (at the attribute level, not the class level), a refresh() operation doesnt + fire the lazy loader or create any problems""" + + s = create_session() + mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))}) + q = s.query(User).options(lazyload('addresses')) + u = q.filter(users.c.id==8).first() + def go(): + s.refresh(u) + self.assert_sql_count(testbase.db, go, 1) + + + def test_refresh_with_eager(self): + """test that a refresh/expire operation loads rows properly and sends correct "isnew" state to eager loaders""" + + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), lazy=False) + }) + + s = create_session() + u = s.get(User, 8) + assert len(u.addresses) == 3 + s.refresh(u) + assert len(u.addresses) == 3 + + s = create_session() + u = s.get(User, 8) + assert len(u.addresses) == 3 + s.expire(u) + assert len(u.addresses) == 3 + + @testing.fails_on('maxdb') + def test_refresh2(self): + """test a hang condition that was occuring on expire/refresh""" + + s = create_session() + mapper(Address, addresses) + + mapper(User, users, properties = dict(addresses=relation(Address,cascade="all, delete-orphan",lazy=False)) ) + + u=User() + u.name='Justin' + a = Address(id=10, email_address='lala') + u.addresses.append(a) + + s.save(u) + s.flush() + s.clear() + u = s.query(User).filter(User.name=='Justin').one() + + s.expire(u) + assert u.name == 'Justin' + + s.refresh(u) + +if __name__ == '__main__': + testbase.main() diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py index 9fa7fffbac..32420300f2 100644 --- a/test/orm/inheritance/basic.py +++ b/test/orm/inheritance/basic.py @@ -84,7 +84,7 @@ class GetTest(ORMTest): Column('bar_id', Integer, ForeignKey('bar.id')), Column('data', String(20))) - def create_test(polymorphic): + def create_test(polymorphic, name): def test_get(self): class Foo(object): pass @@ -145,11 +145,11 @@ class GetTest(ORMTest): assert sess.query(Blub).get(bl.id) == bl self.assert_sql_count(testbase.db, go, 3) - + test_get.__name__ = name return test_get - test_get_polymorphic = create_test(True) - test_get_nonpolymorphic = create_test(False) + test_get_polymorphic = create_test(True, 'test_get_polymorphic') + test_get_nonpolymorphic = create_test(False, 'test_get_nonpolymorphic') class ConstructionTest(ORMTest): diff --git a/test/orm/lazy_relations.py b/test/orm/lazy_relations.py index b8e92c1637..97eda30063 100644 --- a/test/orm/lazy_relations.py +++ b/test/orm/lazy_relations.py @@ -8,12 +8,10 @@ from testlib import * from testlib.fixtures import * from query import QueryTest -class LazyTest(QueryTest): +class LazyTest(FixtureTest): keep_mappers = False - - def setup_mappers(self): - pass - + keep_data = True + def test_basic(self): mapper(User, users, properties={ 'addresses':relation(mapper(Address, addresses), lazy=True) @@ -275,13 +273,10 @@ class LazyTest(QueryTest): assert a.user is u1 -class M2OGetTest(QueryTest): +class M2OGetTest(FixtureTest): keep_mappers = False - keep_data = False + keep_data = True - def setup_mappers(self): - pass - def test_m2o_noload(self): """test that a NULL foreign key doesn't trigger a lazy load""" mapper(User, users) diff --git a/test/orm/mapper.py b/test/orm/mapper.py index fe763c0bed..5ed4795941 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -67,67 +67,12 @@ class MapperTest(MapperSuperTest): u2 = s.query(User).filter_by(user_name='jack').one() assert u is u2 - def test_refresh(self): - mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), backref='user')}) - s = create_session() - u = s.get(User, 7) - u.user_name = 'foo' - a = Address() - assert object_session(a) is None - u.addresses.append(a) - - self.assert_(a in u.addresses) - - s.refresh(u) - - # its refreshed, so not dirty - self.assert_(u not in s.dirty) - - # username is back to the DB - self.assert_(u.user_name == 'jack') - - self.assert_(a not in u.addresses) - - u.user_name = 'foo' - u.addresses.append(a) - # now its dirty - self.assert_(u in s.dirty) - self.assert_(u.user_name == 'foo') - self.assert_(a in u.addresses) - s.expire(u) - - # get the attribute, it refreshes - self.assert_(u.user_name == 'jack') - self.assert_(a not in u.addresses) def test_compileonsession(self): m = mapper(User, users) session = create_session() session.connection(m) - def test_expirecascade(self): - mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), cascade="all, refresh-expire")}) - s = create_session() - u = s.get(User, 8) - u.addresses[0].email_address = 'someotheraddress' - s.expire(u) - assert u.addresses[0].email_address == 'ed@wood.com' - - def test_refreshwitheager(self): - """test that a refresh/expire operation loads rows properly and sends correct "isnew" state to eager loaders""" - mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=False)}) - s = create_session() - u = s.get(User, 8) - assert len(u.addresses) == 3 - s.refresh(u) - assert len(u.addresses) == 3 - - s = create_session() - u = s.get(User, 8) - assert len(u.addresses) == 3 - s.expire(u) - assert len(u.addresses) == 3 - def test_incompletecolumns(self): """test loading from a select which does not contain all columns""" mapper(Address, addresses) @@ -186,71 +131,6 @@ class MapperTest(MapperSuperTest): except Exception, e: assert e is ex - def test_refresh_lazy(self): - """test that when a lazy loader is set as a trigger on an object's attribute (at the attribute level, not the class level), a refresh() operation doesnt fire the lazy loader or create any problems""" - s = create_session() - mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))}) - q2 = s.query(User).options(lazyload('addresses')) - u = q2.selectfirst(users.c.user_id==8) - def go(): - s.refresh(u) - self.assert_sql_count(testbase.db, go, 1) - - def test_expire(self): - """test the expire function""" - s = create_session() - mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=False)}) - u = s.get(User, 7) - assert(len(u.addresses) == 1) - u.user_name = 'foo' - del u.addresses[0] - s.expire(u) - # test plain expire - self.assert_(u.user_name =='jack') - self.assert_(len(u.addresses) == 1) - - # we're changing the database here, so if this test fails in the middle, - # it'll screw up the other tests which are hardcoded to 7/'jack' - u.user_name = 'foo' - s.flush() - # change the value in the DB - users.update(users.c.user_id==7, values=dict(user_name='jack')).execute() - s.expire(u) - # object isnt refreshed yet, using dict to bypass trigger - self.assert_(u.__dict__.get('user_name') != 'jack') - # do a select - s.query(User).select() - # test that it refreshed - self.assert_(u.__dict__['user_name'] == 'jack') - - # object should be back to normal now, - # this should *not* produce a SELECT statement (not tested here though....) - self.assert_(u.user_name =='jack') - - @testing.fails_on('maxdb') - def test_refresh2(self): - """test a hang condition that was occuring on expire/refresh""" - - s = create_session() - m1 = mapper(Address, addresses) - - m2 = mapper(User, users, properties = dict(addresses=relation(Address,private=True,lazy=False)) ) - u=User() - u.user_name='Justin' - a = Address() - a.address_id=17 # to work around the hardcoded IDs in this test suite.... - u.addresses.append(a) - s.flush() - s.clear() - u = s.query(User).selectfirst() - print u.user_name - - #ok so far - s.expire(u) #hangs when - print u.user_name #this line runs - - s.refresh(u) #hangs - def test_props(self): m = mapper(User, users, properties = { 'addresses' : relation(mapper(Address, addresses)) @@ -299,7 +179,18 @@ class MapperTest(MapperSuperTest): sess.save(u3) sess.flush() sess.rollback() - + + def test_illegal_non_primary(self): + mapper(User, users) + mapper(Address, addresses) + try: + mapper(User, users, non_primary=True, properties={ + 'addresses':relation(Address) + }).compile() + assert False + except exceptions.ArgumentError, e: + assert "Attempting to assign a new relation 'addresses' to a non-primary mapper on class 'User'" in str(e) + def test_propfilters(self): t = Table('person', MetaData(), Column('id', Integer, primary_key=True), diff --git a/test/orm/query.py b/test/orm/query.py index 4f19c2c32a..438bc9634f 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -12,16 +12,11 @@ from testlib.fixtures import * class QueryTest(FixtureTest): keep_mappers = True keep_data = True - + def setUpAll(self): super(QueryTest, self).setUpAll() - install_fixture_data() self.setup_mappers() - def tearDownAll(self): - clear_mappers() - super(QueryTest, self).tearDownAll() - def setup_mappers(self): mapper(User, users, properties={ 'addresses':relation(Address, backref='user'), diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index b985cc8a50..28efbf056c 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -119,8 +119,10 @@ class VersioningTest(ORMTest): except exceptions.ConcurrentModificationError, e: assert True # reload it + print "RELOAD" s1.query(Foo).load(f1s1.id) # now assert version OK + print "VERSIONCHECK" s1.query(Foo).with_lockmode('read').get(f1s1.id) # assert brand new load is OK too diff --git a/test/testlib/fixtures.py b/test/testlib/fixtures.py index 9105a29378..022fce0943 100644 --- a/test/testlib/fixtures.py +++ b/test/testlib/fixtures.py @@ -193,6 +193,17 @@ def install_fixture_data(): ) class FixtureTest(ORMTest): + refresh_data = False + + def setUpAll(self): + super(FixtureTest, self).setUpAll() + if self.keep_data: + install_fixture_data() + + def setUp(self): + if self.refresh_data: + install_fixture_data() + def define_tables(self, meta): pass FixtureTest.metadata = metadata