columns to the result set. This eliminates a JOIN from all eager loads
with LIMIT/OFFSET. [ticket:843]
+ - session.refresh() and session.expire() now support an additional argument
+ "attribute_names", a list of individual attribute keynames to be refreshed
+ or expired, allowing partial reloads of attributes on an already-loaded
+ instance.
+
- Mapped classes may now define __eq__, __hash__, and __nonzero__ methods
  with arbitrary semantics. The orm now handles all mapped instances on
an identity-only basis. (e.g. 'is' vs '==') [ticket:676]
session.expire(obj1)
session.expire(obj2)
+`refresh()` and `expire()` also support being passed a list of individual attribute names to be refreshed. These names can reference any attribute, column-based or relation-based:
+
+ {python}
+ # immediately re-load the attributes 'hello', 'world' on obj1, obj2
+ session.refresh(obj1, ['hello', 'world'])
+ session.refresh(obj2, ['hello', 'world'])
+
+    # expire the attributes 'hello', 'world' on objects obj1, obj2; attributes will be reloaded
+ # on the next access:
+ session.expire(obj1, ['hello', 'world'])
+ session.expire(obj2, ['hello', 'world'])
+
## Cascades
Mappers support the concept of configurable *cascade* behavior on `relation()`s. This behavior controls how the Session should treat the instances that have a parent-child relationship with another instance that is operated upon by the Session. Cascade is indicated as a comma-separated list of string keywords, with the possible values `all`, `delete`, `save-update`, `refresh-expire`, `merge`, `expunge`, and `delete-orphan`.
if callable_ is None:
self.initialize(state)
else:
- state.callables[self] = callable_
+ state.callables[self.key] = callable_
def _get_callable(self, state):
- if self in state.callables:
- return state.callables[self]
+ if self.key in state.callables:
+ return state.callables[self.key]
elif self.callable_ is not None:
return self.callable_(state.obj())
else:
return None
- def reset(self, state):
- """Remove any per-instance callable functions corresponding to
- this ``InstrumentedAttribute``'s attribute from the given
- object, and remove this ``InstrumentedAttribute``'s attribute
- from the given object's dictionary.
- """
-
- try:
- del state.callables[self]
- except KeyError:
- pass
- self.clear(state)
-
- def clear(self, state):
- """Remove this ``InstrumentedAttribute``'s attribute from the given object's dictionary.
-
- Subsequent calls to ``getattr(obj, key)`` will raise an
- ``AttributeError`` by default.
- """
-
- try:
- del state.dict[self.key]
- except KeyError:
- pass
-
def check_mutable_modified(self, state):
return False
try:
return state.dict[self.key]
except KeyError:
- # if an instance-wide "trigger" was set, call that
- # and start again
- if state.trigger:
- state.call_trigger()
- return self.get(state, passive=passive)
callable_ = self._get_callable(state)
if callable_ is not None:
if value is not ATTR_WAS_SET:
return self.set_committed_value(state, value)
else:
+ if self.key not in state.dict:
+ return self.get(state, passive=passive)
return state.dict[self.key]
else:
# Return a new, empty value
state.dict[self.key] = value
return value
- def set_raw_value(self, state, value):
- state.dict[self.key] = value
- return value
-
def fire_append_event(self, state, value, initiator):
state.modified = True
if self.trackparent and value is not None:
if copy_function is None:
copy_function = self.__copy
self.copy = copy_function
-
+ self.accepts_global_callable = True
+
def __copy(self, item):
# scalar values are assumed to be immutable unless a copy function
# is passed
if initiator is self:
return
- # if an instance-wide "trigger" was set, call that
- if state.trigger:
- state.call_trigger()
-
state.dict[self.key] = value
state.modified=True
compare_function=compare_function, mutable_scalars=mutable_scalars, **kwargs)
if compare_function is None:
self.is_equal = identity_equal
+ self.accepts_global_callable = False
def delete(self, state):
old = self.get(state)
if initiator is self:
return
- # if an instance-wide "trigger" was set, call that
- if state.trigger:
- state.call_trigger()
-
old = self.get(state)
state.dict[self.key] = value
self.fire_replace_event(state, value, old, initiator)
copy_function = self.__copy
self.copy = copy_function
+ self.accepts_global_callable = False
+
if typecallable is None:
typecallable = list
self.collection_factory = \
elif setting_type == dict:
value = value.values()
- # if an instance-wide "trigger" was set, call that
- if state.trigger:
- state.call_trigger()
-
old = self.get(state)
old_collection = self.get_collection(state, old)
class InstanceState(object):
"""tracks state information at the instance level."""
- __slots__ = 'class_', 'obj', 'dict', 'committed_state', 'modified', 'trigger', 'callables', 'parents', 'instance_dict', '_strong_obj'
+ __slots__ = 'class_', 'obj', 'dict', 'committed_state', 'modified', 'trigger', 'callables', 'parents', 'instance_dict', '_strong_obj', 'expired_attributes'
def __init__(self, obj):
self.class_ = obj.__class__
self.obj = weakref.ref(obj, self.__cleanup)
self.dict = obj.__dict__
- self.committed_state = None
+ self.committed_state = {}
self.modified = False
self.trigger = None
self.callables = {}
self.dict = self.obj().__dict__
self.callables = {}
self.trigger = None
+
+ def initialize(self, key):
+ getattr(self.class_, key).impl.initialize(self)
- def call_trigger(self):
- trig = self.trigger
- self.trigger = None
- trig()
+ def set_callable(self, key, callable_):
+ self.dict.pop(key, None)
+ self.callables[key] = callable_
+
+ def __fire_trigger(self):
+ self.trigger(self.obj(), self.expired_attributes)
+ for k in self.expired_attributes:
+ self.callables.pop(k, None)
+ self.expired_attributes.clear()
+ return ATTR_WAS_SET
+
+ def expire_attributes(self, attribute_names):
+ if not hasattr(self, 'expired_attributes'):
+ self.expired_attributes = util.Set()
+ if attribute_names is None:
+ for attr in self.class_._sa_attribute_manager.managed_attributes(self.class_):
+ self.dict.pop(attr.impl.key, None)
+ self.callables[attr.impl.key] = self.__fire_trigger
+ self.expired_attributes.add(attr.impl.key)
+ else:
+ for key in attribute_names:
+ self.dict.pop(key, None)
+
+ if not getattr(self.class_, key).impl.accepts_global_callable:
+ continue
+
+ self.callables[key] = self.__fire_trigger
+ self.expired_attributes.add(key)
+
+ def reset(self, key):
+ """remove the given attribute and any callables associated with it."""
+
+ self.dict.pop(key, None)
+ self.callables.pop(key, None)
+
+ def clear(self):
+ """clear all attributes from the instance."""
+
+ for attr in self.class_._sa_attribute_manager.managed_attributes(self.class_):
+ self.dict.pop(attr.impl.key, None)
+
+ def commit(self, keys):
+ """commit all attributes named in the given list of key names.
+
+ This is used by a partial-attribute load operation to mark committed those attributes
+ which were refreshed from the database.
+ """
+
+ for key in keys:
+ getattr(self.class_, key).impl.commit_to_state(self)
+
+ def commit_all(self):
+ """commit all attributes unconditionally.
+
+ This is used after a flush() or a regular instance load or refresh operation
+ to mark committed all populated attributes.
+ """
- def commit(self, manager, obj):
self.committed_state = {}
self.modified = False
- for attr in manager.managed_attributes(obj.__class__):
+ for attr in self.class_._sa_attribute_manager.managed_attributes(self.class_):
attr.impl.commit_to_state(self)
# remove strong ref
self._strong_obj = None
- def rollback(self, manager, obj):
- if not self.committed_state:
- manager._clear(obj)
- else:
- for attr in manager.managed_attributes(obj.__class__):
- if attr.impl.key in self.committed_state:
- if not hasattr(attr.impl, 'get_collection'):
- obj.__dict__[attr.impl.key] = self.committed_state[attr.impl.key]
- else:
- collection = attr.impl.get_collection(self)
- collection.clear_without_event()
- for item in self.committed_state[attr.impl.key]:
- collection.append_without_event(item)
- else:
- if attr.impl.key in self.dict:
- del self.dict[attr.impl.key]
class InstanceDict(UserDict.UserDict):
"""similar to WeakValueDictionary, but wired towards 'state' objects."""
def clear_attribute_cache(self):
self._attribute_cache.clear()
- def rollback(self, *obj):
- """Retrieve the committed history for each object in the given
- list, and rolls back the attributes each instance to their
- original value.
- """
-
- for o in obj:
- o._state.rollback(self, o)
-
- def _clear(self, obj):
- for attr in self.managed_attributes(obj.__class__):
- try:
- del obj.__dict__[attr.impl.key]
- except KeyError:
- pass
-
- def commit(self, *obj):
- """Establish the "committed state" for each object in the given list."""
-
- for o in obj:
- o._state.commit(self, o)
-
def managed_attributes(self, class_):
"""Return a list of all ``InstrumentedAttribute`` objects
associated with the given class.
else:
return [x]
- def trigger_history(self, obj, callable):
- """Clear all managed object attributes and places the given
- `callable` as an attribute-wide *trigger*, which will execute
- upon the next attribute access, after which the trigger is
- removed.
- """
-
- s = obj._state
- self._clear(obj)
- s.committed_state = None
- s.trigger = callable
-
- def untrigger_history(self, obj):
- """Remove a trigger function set by trigger_history.
-
- Does not restore the previous state of the object.
- """
-
- obj._state.trigger = None
-
- def has_trigger(self, obj):
- """Return True if the given object has a trigger function set
- by ``trigger_history()``.
- """
-
- return obj._state.trigger is not None
-
- def reset_instance_attribute(self, obj, key):
- """Remove any per-instance callable functions corresponding to
- given attribute `key` from the given object, and remove this
- attribute from the given object's dictionary.
- """
-
- attr = getattr(obj.__class__, key)
- attr.impl.reset(obj._state)
-
- def is_class_managed(self, class_, key):
- """Return True if the given `key` correponds to an
- instrumented property on the given class.
- """
- return hasattr(class_, key) and isinstance(getattr(class_, key), InstrumentedAttribute)
-
def has_parent(self, class_, obj, key, optimistic=False):
return getattr(class_, key).impl.hasparent(obj._state, optimistic=optimistic)
- def init_instance_attribute(self, obj, key, callable_=None, clear=False):
- """Initialize an attribute on an instance to either a blank
- value, cancelling out any class- or instance-level callables
- that were present, or if a `callable` is supplied set the
- callable to be invoked when the attribute is next accessed.
- """
-
- getattr(obj.__class__, key).impl.set_callable(obj._state, callable_, clear=clear)
-
def _create_prop(self, class_, key, uselist, callable_, typecallable, useobject, **kwargs):
"""Create a scalar property object, defaulting to
``InstrumentedAttribute``, which will communicate change
setattr(class_, key, InstrumentedAttribute(self._create_prop(class_, key, uselist, callable_, useobject=useobject,
typecallable=typecallable, **kwargs), comparator=comparator))
- def set_raw_value(self, instance, key, value):
- getattr(instance.__class__, key).impl.set_raw_value(instance._state, value)
-
- def set_committed_value(self, instance, key, value):
- getattr(instance.__class__, key).impl.set_committed_value(instance._state, value)
-
def init_collection(self, instance, key):
"""Initialize a collection attribute and return the collection adapter."""
attr = getattr(instance.__class__, key).impl
prop = self._getpropbycolumn(c, raiseerror=False)
if prop is None:
continue
- deferred_props.append(prop)
+ deferred_props.append(prop.key)
continue
if c.primary_key or not c.key in params:
continue
self.set_attr_by_column(obj, c, params[c.key])
if deferred_props:
- deferred_load(obj, props=deferred_props)
+ expire_instance(obj, deferred_props)
def delete_obj(self, objects, uowtransaction):
"""Issue ``DELETE`` statements for a list of objects.
return self.__surrogate_mapper or self
- def _instance(self, context, row, result = None, skip_polymorphic=False):
+ def _instance(self, context, row, result=None, skip_polymorphic=False, extension=None, only_load_props=None, refresh_instance=None):
"""Pull an object instance from the given row and append it to
the given result list.
on the instance to also process extra information in the row.
"""
- # apply ExtensionOptions applied to the Query to this mapper,
- # but only if our mapper matches.
- # TODO: what if our mapper inherits from the mapper (i.e. as in a polymorphic load?)
- if context.mapper is self:
- extension = context.extension
- else:
+ if not extension:
extension = self.extension
-
+
if 'translate_row' in extension.methods:
ret = extension.translate_row(self, context, row)
if ret is not EXT_CONTINUE:
row = ret
- if not skip_polymorphic and self.polymorphic_on is not None:
- discriminator = row[self.polymorphic_on]
- if discriminator is not None:
- mapper = self.polymorphic_map[discriminator]
- if mapper is not self:
- if ('polymorphic_fetch', mapper) not in context.attributes:
- context.attributes[('polymorphic_fetch', mapper)] = (self, [t for t in mapper.tables if t not in self.tables])
- row = self.translate_row(mapper, row)
- return mapper._instance(context, row, result=result, skip_polymorphic=True)
+ if refresh_instance is None:
+ if not skip_polymorphic and self.polymorphic_on is not None:
+ discriminator = row[self.polymorphic_on]
+ if discriminator is not None:
+ mapper = self.polymorphic_map[discriminator]
+ if mapper is not self:
+ if ('polymorphic_fetch', mapper) not in context.attributes:
+ context.attributes[('polymorphic_fetch', mapper)] = (self, [t for t in mapper.tables if t not in self.tables])
+ row = self.translate_row(mapper, row)
+ return mapper._instance(context, row, result=result, skip_polymorphic=True)
- # look in main identity map. if its there, we dont do anything to it,
- # including modifying any of its related items lists, as its already
- # been exposed to being modified by the application.
- identitykey = self.identity_key_from_row(row)
+ # determine identity key
+ if refresh_instance:
+ identitykey = refresh_instance._instance_key
+ else:
+ identitykey = self.identity_key_from_row(row)
(session_identity_map, local_identity_map) = (context.session.identity_map, context.identity_map)
-
+
+ # look in main identity map. if present, we only populate
+ # if repopulate flags are set. this block returns the instance.
if identitykey in session_identity_map:
instance = session_identity_map[identitykey]
if identitykey not in local_identity_map:
local_identity_map[identitykey] = instance
isnew = True
- if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
- self.populate_instance(context, instance, row, instancekey=identitykey, isnew=isnew)
+ if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, instancekey=identitykey, isnew=isnew, only_load_props=only_load_props) is EXT_CONTINUE:
+ self.populate_instance(context, instance, row, instancekey=identitykey, isnew=isnew, only_load_props=only_load_props)
if 'append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
if result is not None:
result.append(instance)
+
return instance
- else:
- if self.__should_log_debug:
- self.__log_debug("_instance(): identity key %s not in session" % str(identitykey))
+
+ elif self.__should_log_debug:
+ self.__log_debug("_instance(): identity key %s not in session" % str(identitykey))
- # look in result-local identitymap for it.
+ # look in identity map which is local to this load operation
if identitykey not in local_identity_map:
+ # check that sufficient primary key columns are present
if self.allow_null_pks:
# check if *all* primary key cols in the result are None - this indicates
# an instance of the object is not present in the row.
if None in identitykey[1]:
return None
- # plugin point
if 'create_instance' in extension.methods:
instance = extension.create_instance(self, context, row, self.class_)
if instance is EXT_CONTINUE:
instance = attribute_manager.new_instance(self.class_)
instance._entity_name = self.entity_name
+ instance._instance_key = identitykey
+
if self.__should_log_debug:
self.__log_debug("_instance(): created new instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey)))
+
local_identity_map[identitykey] = instance
isnew = True
else:
+ # instance is already present
instance = local_identity_map[identitykey]
isnew = False
- # call further mapper properties on the row, to pull further
- # instances from the row and possibly populate this item.
+ # populate. note that we still call this for an instance already loaded as additional collection state is present
+ # in subsequent rows (i.e. eagerly loaded collections)
flags = {'instancekey':identitykey, 'isnew':isnew}
- if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, **flags) is EXT_CONTINUE:
- self.populate_instance(context, instance, row, **flags)
+ if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, **flags) is EXT_CONTINUE:
+ self.populate_instance(context, instance, row, only_load_props=only_load_props, **flags)
if 'append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, **flags) is EXT_CONTINUE:
if result is not None:
result.append(instance)
-
- instance._instance_key = identitykey
return instance
"""
if tomapper in self._row_translators:
+ # row translators are cached based on target mapper
return self._row_translators[tomapper](row)
else:
translator = create_row_adapter(self.mapped_table, tomapper.mapped_table, equivalent_columns=self._equivalent_columns)
self._row_translators[tomapper] = translator
return translator(row)
- def populate_instance(self, selectcontext, instance, row, ispostselect=None, isnew=False, **flags):
+ def populate_instance(self, selectcontext, instance, row, ispostselect=None, isnew=False, only_load_props=None, **flags):
"""populate an instance from a result row."""
snapshot = selectcontext.path + (self,)
existing_populators = []
post_processors = []
for prop in self.__props.values():
+ if only_load_props and prop.key not in only_load_props:
+ continue
(newpop, existingpop, post_proc) = selectcontext.exec_with_path(self, prop.key, prop.create_row_processor, selectcontext, self, row)
if newpop is not None:
new_populators.append((prop.key, newpop))
existing_populators.append((prop.key, existingpop))
if post_proc is not None:
post_processors.append(post_proc)
-
+
+ # install a post processor for immediate post-load of joined-table inheriting mappers
poly_select_loader = self._get_poly_select_loader(selectcontext, row)
if poly_select_loader is not None:
post_processors.append(poly_select_loader)
if self.backref is not None:
self.backref.compile(self)
- elif not sessionlib.attribute_manager.is_class_managed(self.parent.class_, self.key):
+ elif not mapper.class_mapper(self.parent.class_).get_property(self.key, raiseerr=False):
raise exceptions.ArgumentError("Attempting to assign a new relation '%s' to a non-primary mapper on class '%s'. New relations can only be added to the primary mapper, i.e. the very first mapper created for class '%s' " % (self.key, self.parent.class_.__name__, self.parent.class_.__name__))
super(PropertyLoader, self).do_init()
return attributes.GenericBackrefExtension(self.key)
-def deferred_load(instance, props):
- """set multiple instance attributes to 'deferred' or 'lazy' load, for the given set of MapperProperty objects.
-
- this will remove the current value of the attribute and set a per-instance
- callable to fire off when the instance is next accessed.
-
- for column-based properties, aggreagtes them into a single list against a single deferred loader
- so that a single column access loads all columns
-
- """
-
- if not props:
- return
- column_props = [p for p in props if isinstance(p, ColumnProperty)]
- callable_ = column_props[0]._get_strategy(strategies.DeferredColumnLoader).setup_loader(instance, props=column_props)
- for p in column_props:
- sessionlib.attribute_manager.init_instance_attribute(instance, p.key, callable_=callable_, clear=True)
-
- for p in [p for p in props if isinstance(p, PropertyLoader)]:
- callable_ = p._get_strategy(strategies.LazyLoader).setup_loader(instance)
- sessionlib.attribute_manager.init_instance_attribute(instance, p.key, callable_=callable_, clear=True)
-
mapper.ColumnProperty = ColumnProperty
-mapper.deferred_load = deferred_load
self._attributes = {}
self._current_path = ()
self._primary_adapter=None
+ self._only_load_props = None
+ self._refresh_instance = None
def _clone(self):
q = Query.__new__(Query)
key column values in the order of the table def's primary key
columns.
"""
-
+ print "LOAD CHECK1"
ret = self._extension.load(self, ident, **kwargs)
if ret is not mapper.EXT_CONTINUE:
return ret
+ print "LOAD CHECK2"
key = self.mapper.identity_key_from_primary_key(ident)
- instance = self._get(key, ident, reload=True, **kwargs)
+ instance = self.populate_existing()._get(key, ident, **kwargs)
+ print "LOAD CHECK3"
if instance is None and raiseerr:
raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident))
return instance
return self._execute_and_instances(context)
def _execute_and_instances(self, querycontext):
- result = self.session.execute(querycontext.statement, params=self._params, mapper=self.mapper)
+ result = self.session.execute(querycontext.statement, params=self._params, mapper=self.mapper, instance=self._refresh_instance)
try:
return iter(self.instances(result, querycontext=querycontext))
finally:
and add_column().
"""
- self.__log_debug("instances()")
-
session = self.session
context = kwargs.pop('querycontext', None)
result = []
else:
result = util.UniqueAppender([])
-
+
+ primary_mapper_args = dict(extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance)
+
for row in cursor.fetchall():
if self._primary_adapter:
- self.select_mapper._instance(context, self._primary_adapter(row), result)
+ self.select_mapper._instance(context, self._primary_adapter(row), result, **primary_mapper_args)
else:
- self.select_mapper._instance(context, row, result)
+ self.select_mapper._instance(context, row, result, **primary_mapper_args)
for proc in process:
proc[0](context, row)
for instance in context.identity_map.values():
context.attributes.get(('populating_mapper', id(instance)), object_mapper(instance))._post_instance(context, instance)
-
+
+ if context.refresh_instance and context.only_load_props and context.refresh_instance._instance_key in context.identity_map:
+ # if refreshing partial instance, do special state commit
+ # affecting only the refreshed attributes
+ context.refresh_instance._state.commit(context.only_load_props)
+ del context.identity_map[context.refresh_instance._instance_key]
+
# store new stuff in the identity map
for instance in context.identity_map.values():
session._register_persistent(instance)
-
+
if mappers_or_columns:
return list(util.OrderedSet(zip(*([result] + [o[1] for o in process]))))
else:
return result.data
- def _get(self, key, ident=None, reload=False, lockmode=None):
+ def _get(self, key=None, ident=None, refresh_instance=None, lockmode=None, only_load_props=None):
lockmode = lockmode or self._lockmode
- if not reload and not self.mapper.always_refresh and lockmode is None:
+ if not self._populate_existing and not refresh_instance and not self.mapper.always_refresh and lockmode is None:
try:
return self.session.identity_map[key]
except KeyError:
pass
-
+
if ident is None:
- ident = key[1]
+ if key is not None:
+ ident = key[1]
else:
ident = util.to_list(ident)
- params = {}
+
+ q = self
- (_get_clause, _get_params) = self.select_mapper._get_clause
- for i, primary_key in enumerate(self.primary_key_columns):
- try:
- params[_get_params[primary_key].key] = ident[i]
- except IndexError:
- raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in self.primary_key_columns]))
+ if ident is not None:
+ params = {}
+ (_get_clause, _get_params) = self.select_mapper._get_clause
+ q = q.filter(_get_clause)
+ for i, primary_key in enumerate(self.primary_key_columns):
+ try:
+ params[_get_params[primary_key].key] = ident[i]
+ except IndexError:
+ raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in self.primary_key_columns]))
+ q = q.params(params)
+
try:
- q = self
if lockmode is not None:
q = q.with_lockmode(lockmode)
- q = q.filter(_get_clause)
- q = q.params(params)._select_context_options(populate_existing=reload, version_check=(lockmode is not None))
+ q = q._select_context_options(populate_existing=refresh_instance is not None, version_check=(lockmode is not None), only_load_props=only_load_props, refresh_instance=refresh_instance)
+ q = q.order_by(None)
# call using all() to avoid LIMIT compilation complexity
return q.all()[0]
except IndexError:
# TODO: doing this off the select_mapper. if its the polymorphic mapper, then
# it has no relations() on it. should we compile those too into the query ? (i.e. eagerloads)
for value in self.select_mapper.iterate_properties:
- context.exec_with_path(self.select_mapper, value.key, value.setup, context)
+ if self._only_load_props and value.key not in self._only_load_props:
+ continue
+ context.exec_with_path(self.select_mapper, value.key, value.setup, context, only_load_props=self._only_load_props)
# additional entities/columns, add those to selection criterion
for tup in self._entities:
statement.append_order_by(*context.eager_order_by)
else:
statement = sql.select(context.primary_columns + context.secondary_columns, whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **self._select_args())
-
if context.eager_joins:
statement.append_from(context.eager_joins, _copy_collection=False)
q._select_context_options(**kwargs)
return list(q)
- def _select_context_options(self, populate_existing=None, version_check=None): #pragma: no cover
- if populate_existing is not None:
+ def _select_context_options(self, populate_existing=None, version_check=None, only_load_props=None, refresh_instance=None): #pragma: no cover
+ if populate_existing:
self._populate_existing = populate_existing
- if version_check is not None:
+ if version_check:
self._version_check = version_check
+ if refresh_instance is not None:
+ self._refresh_instance = refresh_instance
+ if only_load_props:
+ self._only_load_props = util.Set(only_load_props)
return self
def join_to(self, key): #pragma: no cover
self.statement = None
self.populate_existing = query._populate_existing
self.version_check = query._version_check
+ self.only_load_props = query._only_load_props
+ self.refresh_instance = query._refresh_instance
self.identity_map = {}
self.path = ()
self.primary_columns = []
import weakref
from sqlalchemy import util, exceptions, sql, engine
-from sqlalchemy.orm import unitofwork, query, util as mapperutil
+from sqlalchemy.orm import unitofwork, query, attributes, util as mapperutil
from sqlalchemy.orm.mapper import object_mapper as _object_mapper
from sqlalchemy.orm.mapper import class_mapper as _class_mapper
from sqlalchemy.orm.mapper import Mapper
resources of the underlying ``Connection``.
"""
- engine = self.get_bind(mapper, clause=clause)
+ engine = self.get_bind(mapper, clause=clause, **kwargs)
- return self.__connection(engine, close_with_result=True).execute(clause, params or {}, **kwargs)
+ return self.__connection(engine, close_with_result=True).execute(clause, params or {})
def scalar(self, clause, params=None, mapper=None, **kwargs):
"""Like execute() but return a scalar result."""
entity_name = kwargs.pop('entity_name', None)
return self.query(class_, entity_name=entity_name).load(ident, **kwargs)
- def refresh(self, obj):
- """Reload the attributes for the given object from the
- database, clear any changes made.
+ def refresh(self, obj, attribute_names=None):
+ """Refresh the attributes on the given instance.
+
+ When called, a query will be issued
+ to the database which will refresh all attributes with their
+ current value.
+
+ Lazy-loaded relational attributes will remain lazily loaded, so that
+ the instance-wide refresh operation will be followed
+ immediately by the lazy load of that attribute.
+
+ Eagerly-loaded relational attributes will eagerly load within the
+ single refresh operation.
+
+ The ``attribute_names`` argument is an iterable collection
+ of attribute names indicating a subset of attributes to be
+ refreshed.
"""
self._validate_persistent(obj)
- if self.query(obj.__class__)._get(obj._instance_key, reload=True) is None:
+
+ if self.query(obj.__class__)._get(obj._instance_key, refresh_instance=obj, only_load_props=attribute_names) is None:
raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(obj))
- def expire(self, obj):
- """Mark the given object as expired.
-
- This will add an instrumentation to all mapped attributes on
- the instance such that when an attribute is next accessed, the
- session will reload all attributes on the instance from the
- database.
+ def expire(self, obj, attribute_names=None):
+ """Expire the attributes on the given instance.
+
+ The instance's attributes are instrumented such that
+ when an attribute is next accessed, a query will be issued
+ to the database which will refresh all attributes with their
+ current value.
+
+ Lazy-loaded relational attributes will remain lazily loaded, so that
+ triggering one will incur the instance-wide refresh operation, followed
+ immediately by the lazy load of that attribute.
+
+ Eagerly-loaded relational attributes will eagerly load within the
+ single refresh operation.
+
+ The ``attribute_names`` argument is an iterable collection
+ of attribute names indicating a subset of attributes to be
+ expired.
"""
-
- for c in [obj] + list(_object_mapper(obj).cascade_iterator('refresh-expire', obj)):
- self._expire_impl(c)
+
+ if attribute_names:
+ self._validate_persistent(obj)
+ expire_instance(obj, attribute_names=attribute_names)
+ else:
+ for c in [obj] + list(_object_mapper(obj).cascade_iterator('refresh-expire', obj)):
+ self._validate_persistent(obj)
+ expire_instance(c, None)
def prune(self):
"""Removes unreferenced instances cached in the identity map.
return self.uow.prune_identity_map()
- def _expire_impl(self, obj):
- self._validate_persistent(obj)
-
- def exp():
- if self.query(obj.__class__)._get(obj._instance_key, reload=True) is None:
- raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(obj))
-
- attribute_manager.trigger_history(obj, exp)
-
def is_expired(self, obj, unexpire=False):
"""Return True if the given object has been marked as expired."""
- ret = attribute_manager.has_trigger(obj)
+ ret = obj._state.trigger is not None
if ret and unexpire:
- attribute_manager.untrigger_history(obj)
+ obj._state.trigger = None
return ret
def expunge(self, object):
def _register_persistent(self, obj):
obj._sa_session_id = self.hash_key
self.identity_map[obj._instance_key] = obj
- attribute_manager.commit(obj)
+ obj._state.commit_all()
def _attach(self, obj):
old_id = getattr(obj, '_sa_session_id', None)
new = property(lambda s:s.uow.new,
doc="A ``Set`` of all objects marked as 'new' within this ``Session``.")
+def expire_instance(obj, attribute_names):
+ """standalone expire instance function.
+
+ installs a callable with the given instance's _state
+ which will fire off when any of the named attributes are accessed;
+ their existing value is removed.
+
+ If the list is None or blank, the entire instance is expired.
+ """
+
+ if obj._state.trigger is None:
+ def load_attributes(instance, attribute_names):
+ if object_session(instance).query(instance.__class__)._get(instance._instance_key, refresh_instance=instance, only_load_props=attribute_names) is None:
+ raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance))
+ obj._state.trigger = load_attributes
+
+ obj._state.expire_attributes(attribute_names)
+
+
# this is the AttributeManager instance used to provide attribute behavior on objects.
# to all the "global variable police" out there: its a stateless object.
unitofwork.object_session = object_session
from sqlalchemy.orm import mapper
mapper.attribute_manager = attribute_manager
+mapper.expire_instance = expire_instance
\ No newline at end of file
def new_execute(instance, row, isnew, **flags):
if isnew:
- loader = strategy.setup_loader(instance, props=props, create_statement=create_statement)
- sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=loader)
+ instance._state.set_callable(self.key, strategy.setup_loader(instance, props=props, create_statement=create_statement))
if self._should_log_debug:
self.logger.debug("Returning deferred column fetcher for %s %s" % (mapper, self.key))
"""Deferred column loader, a per-column or per-column-group lazy loader."""
def create_row_processor(self, selectcontext, mapper, row):
- if (self.group is not None and selectcontext.attributes.get(('undefer', self.group), False)) or self.columns[0] in row:
+ if self.columns[0] in row:
return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, mapper, row)
elif not self.is_class_level or len(selectcontext.options):
def new_execute(instance, row, **flags):
if self._should_log_debug:
self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key))
- sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self.setup_loader(instance))
+ instance._state.set_callable(self.key, self.setup_loader(instance))
return (new_execute, None, None)
else:
def new_execute(instance, row, **flags):
if self._should_log_debug:
self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key))
- sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
+ instance._state.reset(self.key)
return (new_execute, None, None)
def init(self):
self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__))
sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, callable_=self.setup_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator)
- def setup_query(self, context, **kwargs):
- if self.group is not None and context.attributes.get(('undefer', self.group), False):
+ def setup_query(self, context, only_load_props=None, **kwargs):
+ if \
+ (self.group is not None and context.attributes.get(('undefer', self.group), False)) or \
+ (only_load_props and self.key in only_load_props):
+
self.parent_property._get_strategy(ColumnLoader).setup_query(context, **kwargs)
def setup_loader(self, instance, props=None, create_statement=None):
raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
if create_statement is None:
- (clause, param_map) = localparent._get_clause
ident = instance._instance_key[1]
- params = {}
- for i, primary_key in enumerate(localparent.primary_key):
- params[param_map[primary_key].key] = ident[i]
- statement = sql.select([p.columns[0] for p in group], clause, from_obj=[localparent.mapped_table], use_labels=True)
+ session.query(localparent)._get(None, ident=ident, only_load_props=[p.key for p in group], refresh_instance=instance)
else:
statement, params = create_statement(instance)
-
- # TODO: have the "fetch of one row" operation go through the same channels as a query._get()
- # deferred load of several attributes should be a specialized case of a query refresh operation
- conn = session.connection(mapper=localparent, instance=instance)
- result = conn.execute(statement, params)
- try:
- row = result.fetchone()
- for prop in group:
- sessionlib.attribute_manager.set_committed_value(instance, prop.key, row[prop.columns[0]])
- return attributes.ATTR_WAS_SET
- finally:
- result.close()
-
+ session.query(localparent).from_statement(statement).params(params)._get(None, only_load_props=[p.key for p in group], refresh_instance=instance)
+ return attributes.ATTR_WAS_SET
return lazyload
DeferredColumnLoader.logger = logging.class_logger(DeferredColumnLoader)
self._should_log_debug = logging.is_debug_enabled(self.logger)
def _init_instance_attribute(self, instance, callable_=None):
- return sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=callable_)
+ if callable_:
+ instance._state.set_callable(self.key, callable_)
+ else:
+ instance._state.initialize(self.key)
def _register_attribute(self, class_, callable_=None, **kwargs):
self.logger.info("register managed %s attribute %s on class %s" % ((self.uselist and "list-holding" or "scalar"), self.key, self.parent.class_.__name__))
# so that the class-level lazy loader is executed when next referenced on this instance.
# this usually is not needed unless the constructor of the object referenced the attribute before we got
# to load data into it.
- sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
+ instance._state.reset(self.key)
return (new_execute, None, None)
def _create_lazy_clause(cls, prop, reverse_direction=False):
# parent object, bypassing InstrumentedAttribute
# event handlers.
#
- # FIXME: instead of...
- sessionlib.attribute_manager.set_raw_value(instance, self.key, self.select_mapper._instance(selectcontext, decorated_row, None))
- # bypass and set directly:
- #instance.__dict__[self.key] = self.select_mapper._instance(selectcontext, decorated_row, None)
+ instance.__dict__[self.key] = self.select_mapper._instance(selectcontext, decorated_row, None)
else:
# call _instance on the row, even though the object has been created,
# so that we further descend into properties
if hasattr(obj, '_sa_insert_order'):
delattr(obj, '_sa_insert_order')
self.identity_map[obj._instance_key] = obj
- attribute_manager.commit(obj)
+ obj._state.commit_all()
def register_new(self, obj):
"""register the given object as 'new' (i.e. unsaved) within this unit of work."""
'orm.lazy_relations',
'orm.eager_relations',
'orm.mapper',
+ 'orm.expire',
'orm.selectable',
'orm.collection',
'orm.generative',
from sqlalchemy import exceptions
from testlib import *
+ROLLBACK_SUPPORTED=False
+
# these test classes defined at the module
# level to support pickling
class MyTest(object):pass
print repr(u.__dict__)
self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
- manager.commit(u)
+ u._state.commit_all()
print repr(u.__dict__)
self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
u.email_address = 'foo@bar.com'
print repr(u.__dict__)
self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.email_address == 'foo@bar.com')
-
- manager.rollback(u)
- print repr(u.__dict__)
- self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
+
+ if ROLLBACK_SUPPORTED:
+ manager.rollback(u)
+ print repr(u.__dict__)
+ self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
def test_pickleness(self):
print repr(u.__dict__)
self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
-        manager.commit(u, a)
+        u._state.commit_all(); a._state.commit_all()
print repr(u.__dict__)
self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
print repr(u.__dict__)
self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.addresses[0].email_address == 'lala@123.com' and u.addresses[1].email_address == 'foo@bar.com')
- manager.rollback(u, a)
- print repr(u.__dict__)
- print repr(u.addresses[0].__dict__)
- self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
- self.assert_(len(manager.get_history(u, 'addresses').unchanged_items()) == 1)
+ if ROLLBACK_SUPPORTED:
+ manager.rollback(u, a)
+ print repr(u.__dict__)
+ print repr(u.addresses[0].__dict__)
+ self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
+ self.assert_(len(manager.get_history(u, 'addresses').unchanged_items()) == 1)
def test_backref(self):
class Student(object):pass
# create objects as if they'd been freshly loaded from the database (without history)
b = Blog()
p1 = Post()
- manager.init_instance_attribute(b, 'posts', lambda:[p1])
- manager.init_instance_attribute(p1, 'blog', lambda:b)
- manager.commit(p1, b)
+        b._state.set_callable('posts', lambda:[p1])
+        p1._state.set_callable('blog', lambda:b)
+        p1._state.commit_all(); b._state.commit_all()
# no orphans (called before the lazy loaders fire off)
assert manager.has_parent(Blog, p1, 'posts', optimistic=True)
x.element = 'this is the element'
hist = manager.get_history(x, 'element')
assert hist.added_items() == ['this is the element']
- manager.commit(x)
+ x._state.commit_all()
hist = manager.get_history(x, 'element')
assert hist.added_items() == []
assert hist.unchanged_items() == ['this is the element']
manager.register_attribute(Bar, 'id', uselist=False, useobject=True)
x = Foo()
- manager.commit(x)
+ x._state.commit_all()
x.col2.append(Bar(4))
h = manager.get_history(x, 'col2')
print h.added_items()
manager.register_attribute(Foo, 'element', uselist=False, copy_function=lambda x:[y for y in x], mutable_scalars=True, useobject=False)
x = Foo()
x.element = ['one', 'two', 'three']
- manager.commit(x)
+ x._state.commit_all()
x.element[1] = 'five'
assert manager.is_modified(x)
manager.register_attribute(Foo, 'element', uselist=False, useobject=False)
x = Foo()
x.element = ['one', 'two', 'three']
- manager.commit(x)
+ x._state.commit_all()
x.element[1] = 'five'
assert not manager.is_modified(x)
from query import QueryTest
-class DynamicTest(QueryTest):
+class DynamicTest(FixtureTest):
keep_mappers = False
-
- def setup_mappers(self):
- pass
-
+ keep_data = True
+
def test_basic(self):
mapper(User, users, properties={
'addresses':dynamic_loader(mapper(Address, addresses))
from testlib.fixtures import *
from query import QueryTest
-class EagerTest(QueryTest):
+class EagerTest(FixtureTest):
keep_mappers = False
-
+ keep_data = True
+
def setup_mappers(self):
pass
--- /dev/null
+"""test attribute/instance expiration, deferral of attributes, etc."""
+
+import testbase
+from sqlalchemy import *
+from sqlalchemy import exceptions
+from sqlalchemy.orm import *
+from testlib import *
+from testlib.fixtures import *
+
+class ExpireTest(FixtureTest):
+ keep_mappers = False
+ refresh_data = True
+
+ def test_expire(self):
+ mapper(User, users, properties={
+ 'addresses':relation(Address, backref='user'),
+ })
+ mapper(Address, addresses)
+
+ sess = create_session()
+ u = sess.query(User).get(7)
+ assert len(u.addresses) == 1
+ u.name = 'foo'
+ del u.addresses[0]
+ sess.expire(u)
+
+ assert 'name' not in u.__dict__
+
+ def go():
+ assert u.name == 'jack'
+ self.assert_sql_count(testbase.db, go, 1)
+ assert 'name' in u.__dict__
+
+ # we're changing the database here, so if this test fails in the middle,
+ # it'll screw up the other tests which are hardcoded to 7/'jack'
+ u.name = 'foo'
+ sess.flush()
+ # change the value in the DB
+ users.update(users.c.id==7, values=dict(name='jack')).execute()
+ sess.expire(u)
+        # object isn't refreshed yet, using dict to bypass trigger
+ assert u.__dict__.get('name') != 'jack'
+ # reload all
+ sess.query(User).all()
+ # test that it refreshed
+ assert u.__dict__['name'] == 'jack'
+
+ # object should be back to normal now,
+ # this should *not* produce a SELECT statement (not tested here though....)
+ assert u.name == 'jack'
+
+ def test_expire_doesntload_on_set(self):
+ mapper(User, users)
+
+ sess = create_session()
+ u = sess.query(User).get(7)
+
+ sess.expire(u, attribute_names=['name'])
+ def go():
+ u.name = 'somenewname'
+ self.assert_sql_count(testbase.db, go, 0)
+ sess.flush()
+ sess.clear()
+ assert sess.query(User).get(7).name == 'somenewname'
+
+    def test_expire_committed(self):
+        """test that the committed state of the attribute receives the most recent DB data"""
+        mapper(Order, orders)
+
+        sess = create_session()
+        o = sess.query(Order).get(3)
+        sess.expire(o)
+
+        orders.update(orders.c.id==3).execute(description='order 3 modified')
+        assert o.isopen == 1
+        assert o._state.committed_state['description'] == 'order 3 modified'
+        def go():
+            sess.flush()
+        self.assert_sql_count(testbase.db, go, 0)
+
+ def test_expire_cascade(self):
+ mapper(User, users, properties={
+ 'addresses':relation(Address, cascade="all, refresh-expire")
+ })
+ mapper(Address, addresses)
+ s = create_session()
+ u = s.get(User, 8)
+ assert u.addresses[0].email_address == 'ed@wood.com'
+
+ u.addresses[0].email_address = 'someotheraddress'
+ s.expire(u)
+ u.name
+ print u._state.dict
+ assert u.addresses[0].email_address == 'ed@wood.com'
+
+ def test_expired_lazy(self):
+ mapper(User, users, properties={
+ 'addresses':relation(Address, backref='user'),
+ })
+ mapper(Address, addresses)
+
+ sess = create_session()
+ u = sess.query(User).get(7)
+
+ sess.expire(u)
+ assert 'name' not in u.__dict__
+ assert 'addresses' not in u.__dict__
+
+ def go():
+ assert u.addresses[0].email_address == 'jack@bean.com'
+ assert u.name == 'jack'
+ # two loads
+ self.assert_sql_count(testbase.db, go, 2)
+ assert 'name' in u.__dict__
+ assert 'addresses' in u.__dict__
+
+ def test_expired_eager(self):
+ mapper(User, users, properties={
+ 'addresses':relation(Address, backref='user', lazy=False),
+ })
+ mapper(Address, addresses)
+
+ sess = create_session()
+ u = sess.query(User).get(7)
+
+ sess.expire(u)
+ assert 'name' not in u.__dict__
+ assert 'addresses' not in u.__dict__
+
+ def go():
+ assert u.addresses[0].email_address == 'jack@bean.com'
+ assert u.name == 'jack'
+ # one load
+ self.assert_sql_count(testbase.db, go, 1)
+ assert 'name' in u.__dict__
+ assert 'addresses' in u.__dict__
+
+ def test_partial_expire(self):
+ mapper(Order, orders)
+
+ sess = create_session()
+ o = sess.query(Order).get(3)
+
+ sess.expire(o, attribute_names=['description'])
+ assert 'id' in o.__dict__
+ assert 'description' not in o.__dict__
+ assert o._state.committed_state['isopen'] == 1
+
+ orders.update(orders.c.id==3).execute(description='order 3 modified')
+
+ def go():
+ assert o.description == 'order 3 modified'
+ self.assert_sql_count(testbase.db, go, 1)
+ assert o._state.committed_state['description'] == 'order 3 modified'
+
+ o.isopen = 5
+ sess.expire(o, attribute_names=['description'])
+ assert 'id' in o.__dict__
+ assert 'description' not in o.__dict__
+ assert o.__dict__['isopen'] == 5
+ assert o._state.committed_state['isopen'] == 1
+
+ def go():
+ assert o.description == 'order 3 modified'
+ self.assert_sql_count(testbase.db, go, 1)
+ assert o.__dict__['isopen'] == 5
+ assert o._state.committed_state['description'] == 'order 3 modified'
+ assert o._state.committed_state['isopen'] == 1
+
+ sess.flush()
+
+ sess.expire(o, attribute_names=['id', 'isopen', 'description'])
+ assert 'id' not in o.__dict__
+ assert 'isopen' not in o.__dict__
+ assert 'description' not in o.__dict__
+ def go():
+ assert o.description == 'order 3 modified'
+ assert o.id == 3
+ assert o.isopen == 5
+ self.assert_sql_count(testbase.db, go, 1)
+
+ def test_partial_expire_lazy(self):
+ mapper(User, users, properties={
+ 'addresses':relation(Address, backref='user'),
+ })
+ mapper(Address, addresses)
+
+ sess = create_session()
+ u = sess.query(User).get(8)
+
+ sess.expire(u, ['name', 'addresses'])
+ assert 'name' not in u.__dict__
+ assert 'addresses' not in u.__dict__
+
+ # hit the lazy loader. just does the lazy load,
+        # doesn't do the overall refresh
+ def go():
+ assert u.addresses[0].email_address=='ed@wood.com'
+ self.assert_sql_count(testbase.db, go, 1)
+
+ assert 'name' not in u.__dict__
+
+ # check that mods to expired lazy-load attributes
+ # only do the lazy load
+ sess.expire(u, ['name', 'addresses'])
+ def go():
+ u.addresses = [Address(id=10, email_address='foo@bar.com')]
+ self.assert_sql_count(testbase.db, go, 1)
+
+ sess.flush()
+
+ # flush has occurred, and addresses was modified,
+ # so the addresses collection got committed and is
+        # no longer expired
+ def go():
+ assert u.addresses[0].email_address=='foo@bar.com'
+ assert len(u.addresses) == 1
+ self.assert_sql_count(testbase.db, go, 0)
+
+ # but the name attribute was never loaded and so
+ # still loads
+ def go():
+ assert u.name == 'ed'
+ self.assert_sql_count(testbase.db, go, 1)
+
+ def test_partial_expire_eager(self):
+ mapper(User, users, properties={
+ 'addresses':relation(Address, backref='user', lazy=False),
+ })
+ mapper(Address, addresses)
+
+ sess = create_session()
+ u = sess.query(User).get(8)
+
+ sess.expire(u, ['name', 'addresses'])
+ assert 'name' not in u.__dict__
+ assert 'addresses' not in u.__dict__
+
+ def go():
+ assert u.addresses[0].email_address=='ed@wood.com'
+ self.assert_sql_count(testbase.db, go, 1)
+
+ # check that mods to expired eager-load attributes
+ # do the refresh
+ sess.expire(u, ['name', 'addresses'])
+ def go():
+ u.addresses = [Address(id=10, email_address='foo@bar.com')]
+ self.assert_sql_count(testbase.db, go, 1)
+ sess.flush()
+
+ # this should ideally trigger the whole load
+ # but currently it works like the lazy case
+ def go():
+ assert u.addresses[0].email_address=='foo@bar.com'
+ assert len(u.addresses) == 1
+ self.assert_sql_count(testbase.db, go, 0)
+
+ def go():
+ assert u.name == 'ed'
+ # scalar attributes have their own load
+ self.assert_sql_count(testbase.db, go, 1)
+        # ideally, this was already loaded, but we aren't
+ # doing it that way right now
+ #self.assert_sql_count(testbase.db, go, 0)
+
+ def test_partial_expire_deferred(self):
+ mapper(Order, orders, properties={
+ 'description':deferred(orders.c.description)
+ })
+
+ sess = create_session()
+ o = sess.query(Order).get(3)
+ sess.expire(o, ['description', 'isopen'])
+ assert 'isopen' not in o.__dict__
+ assert 'description' not in o.__dict__
+
+ # test that expired attribute access refreshes
+ # the deferred
+ def go():
+ assert o.isopen == 1
+ assert o.description == 'order 3'
+ self.assert_sql_count(testbase.db, go, 1)
+
+ sess.expire(o, ['description', 'isopen'])
+ assert 'isopen' not in o.__dict__
+ assert 'description' not in o.__dict__
+ # test that the deferred attribute triggers the full
+ # reload
+ def go():
+ assert o.description == 'order 3'
+ assert o.isopen == 1
+ self.assert_sql_count(testbase.db, go, 1)
+
+ clear_mappers()
+
+ mapper(Order, orders)
+ sess.clear()
+
+ # same tests, using deferred at the options level
+ o = sess.query(Order).options(defer('description')).get(3)
+
+ assert 'description' not in o.__dict__
+
+ # sanity check
+ def go():
+ assert o.description == 'order 3'
+ self.assert_sql_count(testbase.db, go, 1)
+
+ assert 'description' in o.__dict__
+ assert 'isopen' in o.__dict__
+ sess.expire(o, ['description', 'isopen'])
+ assert 'isopen' not in o.__dict__
+ assert 'description' not in o.__dict__
+
+ # test that expired attribute access refreshes
+ # the deferred
+ def go():
+ assert o.isopen == 1
+ assert o.description == 'order 3'
+ self.assert_sql_count(testbase.db, go, 1)
+ sess.expire(o, ['description', 'isopen'])
+
+ assert 'isopen' not in o.__dict__
+ assert 'description' not in o.__dict__
+ # test that the deferred attribute triggers the full
+ # reload
+ def go():
+ assert o.description == 'order 3'
+ assert o.isopen == 1
+ self.assert_sql_count(testbase.db, go, 1)
+
+
+class RefreshTest(FixtureTest):
+ keep_mappers = False
+ refresh_data = True
+
+ def test_refresh(self):
+ mapper(User, users, properties={
+ 'addresses':relation(mapper(Address, addresses), backref='user')
+ })
+ s = create_session()
+ u = s.get(User, 7)
+ u.name = 'foo'
+ a = Address()
+ assert object_session(a) is None
+ u.addresses.append(a)
+ assert a.email_address is None
+ assert id(a) in [id(x) for x in u.addresses]
+
+ s.refresh(u)
+
+ # its refreshed, so not dirty
+ assert u not in s.dirty
+
+ # username is back to the DB
+ assert u.name == 'jack'
+
+ assert id(a) not in [id(x) for x in u.addresses]
+
+ u.name = 'foo'
+ u.addresses.append(a)
+ # now its dirty
+ assert u in s.dirty
+ assert u.name == 'foo'
+ assert id(a) in [id(x) for x in u.addresses]
+ s.expire(u)
+
+ # get the attribute, it refreshes
+ assert u.name == 'jack'
+ assert id(a) not in [id(x) for x in u.addresses]
+
+ def test_refresh_expired(self):
+ mapper(User, users)
+ s = create_session()
+ u = s.get(User, 7)
+ s.expire(u)
+ assert 'name' not in u.__dict__
+ s.refresh(u)
+ assert u.name == 'jack'
+
+ def test_refresh_with_lazy(self):
+ """test that when a lazy loader is set as a trigger on an object's attribute
+        (at the attribute level, not the class level), a refresh() operation doesn't
+ fire the lazy loader or create any problems"""
+
+ s = create_session()
+ mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))})
+ q = s.query(User).options(lazyload('addresses'))
+ u = q.filter(users.c.id==8).first()
+ def go():
+ s.refresh(u)
+ self.assert_sql_count(testbase.db, go, 1)
+
+
+ def test_refresh_with_eager(self):
+ """test that a refresh/expire operation loads rows properly and sends correct "isnew" state to eager loaders"""
+
+ mapper(User, users, properties={
+ 'addresses':relation(mapper(Address, addresses), lazy=False)
+ })
+
+ s = create_session()
+ u = s.get(User, 8)
+ assert len(u.addresses) == 3
+ s.refresh(u)
+ assert len(u.addresses) == 3
+
+ s = create_session()
+ u = s.get(User, 8)
+ assert len(u.addresses) == 3
+ s.expire(u)
+ assert len(u.addresses) == 3
+
+ @testing.fails_on('maxdb')
+ def test_refresh2(self):
+ """test a hang condition that was occuring on expire/refresh"""
+
+ s = create_session()
+ mapper(Address, addresses)
+
+ mapper(User, users, properties = dict(addresses=relation(Address,cascade="all, delete-orphan",lazy=False)) )
+
+ u=User()
+ u.name='Justin'
+ a = Address(id=10, email_address='lala')
+ u.addresses.append(a)
+
+ s.save(u)
+ s.flush()
+ s.clear()
+ u = s.query(User).filter(User.name=='Justin').one()
+
+ s.expire(u)
+ assert u.name == 'Justin'
+
+ s.refresh(u)
+
+if __name__ == '__main__':
+ testbase.main()
Column('bar_id', Integer, ForeignKey('bar.id')),
Column('data', String(20)))
- def create_test(polymorphic):
+ def create_test(polymorphic, name):
def test_get(self):
class Foo(object):
pass
assert sess.query(Blub).get(bl.id) == bl
self.assert_sql_count(testbase.db, go, 3)
-
+ test_get.__name__ = name
return test_get
- test_get_polymorphic = create_test(True)
- test_get_nonpolymorphic = create_test(False)
+ test_get_polymorphic = create_test(True, 'test_get_polymorphic')
+ test_get_nonpolymorphic = create_test(False, 'test_get_nonpolymorphic')
class ConstructionTest(ORMTest):
from testlib.fixtures import *
from query import QueryTest
-class LazyTest(QueryTest):
+class LazyTest(FixtureTest):
keep_mappers = False
-
- def setup_mappers(self):
- pass
-
+ keep_data = True
+
def test_basic(self):
mapper(User, users, properties={
'addresses':relation(mapper(Address, addresses), lazy=True)
assert a.user is u1
-class M2OGetTest(QueryTest):
+class M2OGetTest(FixtureTest):
keep_mappers = False
- keep_data = False
+ keep_data = True
- def setup_mappers(self):
- pass
-
def test_m2o_noload(self):
"""test that a NULL foreign key doesn't trigger a lazy load"""
mapper(User, users)
u2 = s.query(User).filter_by(user_name='jack').one()
assert u is u2
- def test_refresh(self):
- mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), backref='user')})
- s = create_session()
- u = s.get(User, 7)
- u.user_name = 'foo'
- a = Address()
- assert object_session(a) is None
- u.addresses.append(a)
-
- self.assert_(a in u.addresses)
-
- s.refresh(u)
-
- # its refreshed, so not dirty
- self.assert_(u not in s.dirty)
-
- # username is back to the DB
- self.assert_(u.user_name == 'jack')
-
- self.assert_(a not in u.addresses)
-
- u.user_name = 'foo'
- u.addresses.append(a)
- # now its dirty
- self.assert_(u in s.dirty)
- self.assert_(u.user_name == 'foo')
- self.assert_(a in u.addresses)
- s.expire(u)
-
- # get the attribute, it refreshes
- self.assert_(u.user_name == 'jack')
- self.assert_(a not in u.addresses)
def test_compileonsession(self):
m = mapper(User, users)
session = create_session()
session.connection(m)
- def test_expirecascade(self):
- mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), cascade="all, refresh-expire")})
- s = create_session()
- u = s.get(User, 8)
- u.addresses[0].email_address = 'someotheraddress'
- s.expire(u)
- assert u.addresses[0].email_address == 'ed@wood.com'
-
- def test_refreshwitheager(self):
- """test that a refresh/expire operation loads rows properly and sends correct "isnew" state to eager loaders"""
- mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=False)})
- s = create_session()
- u = s.get(User, 8)
- assert len(u.addresses) == 3
- s.refresh(u)
- assert len(u.addresses) == 3
-
- s = create_session()
- u = s.get(User, 8)
- assert len(u.addresses) == 3
- s.expire(u)
- assert len(u.addresses) == 3
-
def test_incompletecolumns(self):
"""test loading from a select which does not contain all columns"""
mapper(Address, addresses)
except Exception, e:
assert e is ex
- def test_refresh_lazy(self):
- """test that when a lazy loader is set as a trigger on an object's attribute (at the attribute level, not the class level), a refresh() operation doesnt fire the lazy loader or create any problems"""
- s = create_session()
- mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))})
- q2 = s.query(User).options(lazyload('addresses'))
- u = q2.selectfirst(users.c.user_id==8)
- def go():
- s.refresh(u)
- self.assert_sql_count(testbase.db, go, 1)
-
- def test_expire(self):
- """test the expire function"""
- s = create_session()
- mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=False)})
- u = s.get(User, 7)
- assert(len(u.addresses) == 1)
- u.user_name = 'foo'
- del u.addresses[0]
- s.expire(u)
- # test plain expire
- self.assert_(u.user_name =='jack')
- self.assert_(len(u.addresses) == 1)
-
- # we're changing the database here, so if this test fails in the middle,
- # it'll screw up the other tests which are hardcoded to 7/'jack'
- u.user_name = 'foo'
- s.flush()
- # change the value in the DB
- users.update(users.c.user_id==7, values=dict(user_name='jack')).execute()
- s.expire(u)
- # object isnt refreshed yet, using dict to bypass trigger
- self.assert_(u.__dict__.get('user_name') != 'jack')
- # do a select
- s.query(User).select()
- # test that it refreshed
- self.assert_(u.__dict__['user_name'] == 'jack')
-
- # object should be back to normal now,
- # this should *not* produce a SELECT statement (not tested here though....)
- self.assert_(u.user_name =='jack')
-
- @testing.fails_on('maxdb')
- def test_refresh2(self):
- """test a hang condition that was occuring on expire/refresh"""
-
- s = create_session()
- m1 = mapper(Address, addresses)
-
- m2 = mapper(User, users, properties = dict(addresses=relation(Address,private=True,lazy=False)) )
- u=User()
- u.user_name='Justin'
- a = Address()
- a.address_id=17 # to work around the hardcoded IDs in this test suite....
- u.addresses.append(a)
- s.flush()
- s.clear()
- u = s.query(User).selectfirst()
- print u.user_name
-
- #ok so far
- s.expire(u) #hangs when
- print u.user_name #this line runs
-
- s.refresh(u) #hangs
-
def test_props(self):
m = mapper(User, users, properties = {
'addresses' : relation(mapper(Address, addresses))
sess.save(u3)
sess.flush()
sess.rollback()
-
+
+ def test_illegal_non_primary(self):
+ mapper(User, users)
+ mapper(Address, addresses)
+ try:
+ mapper(User, users, non_primary=True, properties={
+ 'addresses':relation(Address)
+ }).compile()
+ assert False
+ except exceptions.ArgumentError, e:
+ assert "Attempting to assign a new relation 'addresses' to a non-primary mapper on class 'User'" in str(e)
+
def test_propfilters(self):
t = Table('person', MetaData(),
Column('id', Integer, primary_key=True),
class QueryTest(FixtureTest):
keep_mappers = True
keep_data = True
-
+
def setUpAll(self):
super(QueryTest, self).setUpAll()
- install_fixture_data()
self.setup_mappers()
- def tearDownAll(self):
- clear_mappers()
- super(QueryTest, self).tearDownAll()
-
def setup_mappers(self):
mapper(User, users, properties={
'addresses':relation(Address, backref='user'),
except exceptions.ConcurrentModificationError, e:
assert True
# reload it
+ print "RELOAD"
s1.query(Foo).load(f1s1.id)
# now assert version OK
+ print "VERSIONCHECK"
s1.query(Foo).with_lockmode('read').get(f1s1.id)
# assert brand new load is OK too
)
class FixtureTest(ORMTest):
+ refresh_data = False
+
+ def setUpAll(self):
+ super(FixtureTest, self).setUpAll()
+ if self.keep_data:
+ install_fixture_data()
+
+ def setUp(self):
+ if self.refresh_data:
+ install_fixture_data()
+
def define_tables(self, meta):
pass
FixtureTest.metadata = metadata