have each method called only once per operation, use the same
instance of the extension for both mappers.
[ticket:490]
+
+ - columns which are missing from a Query's select statement
+ now get automatically deferred during load.
+ - improved support for pickling of mapped entities. Per-instance
+ lazy/deferred/expired callables are now serializable so that
+ they serialize and deserialize with _state.
+
- new synonym() behavior: an attribute will be placed on the mapped
class, if one does not exist already, in all cases. if a property
already exists on the class, the synonym will decorate the property
class ScalarAttributeImpl(AttributeImpl):
"""represents a scalar value-holding InstrumentedAttribute."""
- accepts_global_callable = True
+ accepts_scalar_loader = True
def delete(self, state):
if self.key not in state.committed_state:
state.committed_state[self.key] = state.dict.get(self.key, NO_VALUE)
+ # TODO: catch key errors, convert to attributeerror?
del state.dict[self.key]
state.modified=True
Adds events to delete/set operations.
"""
- accepts_global_callable = False
+ accepts_scalar_loader = False
def __init__(self, class_, key, callable_, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
super(ScalarObjectAttributeImpl, self).__init__(class_, key,
def delete(self, state):
old = self.get(state)
+ # TODO: catch key errors, convert to attributeerror?
del state.dict[self.key]
self.fire_remove_event(state, old, self)
CollectionAdapter, a "view" onto that object that presents consistent
bag semantics to the orm layer independent of the user data implementation.
"""
- accepts_global_callable = False
+ accepts_scalar_loader = False
def __init__(self, class_, key, callable_, typecallable=None, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
super(CollectionAttributeImpl, self).__init__(class_,
collection = self.get_collection(state)
collection.clear_with_event()
+ # TODO: catch key errors, convert to attributeerror?
del state.dict[self.key]
def initialize(self, state):
self.mappers = {}
self.attrs = {}
self.has_mutable_scalars = False
-
+
class InstanceState(object):
"""tracks state information at the instance level."""
self.dict = obj.__dict__
self.committed_state = {}
self.modified = False
- self.trigger = None
self.callables = {}
self.parents = {}
self.pending = {}
return None
def __getstate__(self):
- return {'committed_state':self.committed_state, 'pending':self.pending, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj()}
+ return {'committed_state':self.committed_state, 'pending':self.pending, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj(), 'expired_attributes':getattr(self, 'expired_attributes', None), 'callables':self.callables}
def __setstate__(self, state):
self.committed_state = state['committed_state']
self.obj = weakref.ref(state['instance'])
self.class_ = self.obj().__class__
self.dict = self.obj().__dict__
- self.callables = {}
- self.trigger = None
-
+ self.callables = state['callables']
+ self.runid = None
+ self.appenders = {}
+ if state['expired_attributes'] is not None:
+ self.expire_attributes(state['expired_attributes'])
+
def initialize(self, key):
getattr(self.class_, key).impl.initialize(self)
def set_callable(self, key, callable_):
self.dict.pop(key, None)
self.callables[key] = callable_
-
- def __fire_trigger(self):
+
+ def __call__(self):
+ """__call__ allows the InstanceState to act as a deferred
+ callable for loading expired attributes, which is also
+ serializable.
+ """
instance = self.obj()
- self.trigger(instance, [k for k in self.expired_attributes if k not in self.dict])
+ self.class_._class_state.deferred_scalar_loader(instance, [k for k in self.expired_attributes if k not in self.committed_state])
for k in self.expired_attributes:
self.callables.pop(k, None)
self.expired_attributes.clear()
return ATTR_WAS_SET
+ def unmodified(self):
+ """a set of keys which have no uncommitted changes"""
+
+ return util.Set([
+ attr.impl.key for attr in _managed_attributes(self.class_) if
+ attr.impl.key not in self.committed_state
+ and (not hasattr(attr.impl, 'commit_to_state') or not attr.impl.check_mutable_modified(self))
+ ])
+ unmodified = property(unmodified)
+
def expire_attributes(self, attribute_names):
if not hasattr(self, 'expired_attributes'):
self.expired_attributes = util.Set()
+
if attribute_names is None:
for attr in _managed_attributes(self.class_):
self.dict.pop(attr.impl.key, None)
- self.callables[attr.impl.key] = self.__fire_trigger
- self.expired_attributes.add(attr.impl.key)
+
+ if attr.impl.accepts_scalar_loader:
+ self.callables[attr.impl.key] = self
+ self.expired_attributes.add(attr.impl.key)
+
self.committed_state = {}
else:
for key in attribute_names:
self.dict.pop(key, None)
self.committed_state.pop(key, None)
- if not getattr(self.class_, key).impl.accepts_global_callable:
- continue
-
- self.callables[key] = self.__fire_trigger
- self.expired_attributes.add(key)
+ if getattr(self.class_, key).impl.accepts_scalar_loader:
+ self.callables[key] = self
+ self.expired_attributes.add(key)
def reset(self, key):
"""remove the given attribute and any callables associated with it."""
if not '_class_state' in class_.__dict__:
class_._class_state = ClassState()
-def register_class(class_, extra_init=None, on_exception=None):
+def register_class(class_, extra_init=None, on_exception=None, deferred_scalar_loader=None):
# do a sweep first, this also helps some attribute extensions
# (like associationproxy) become aware of themselves at the
# class level
getattr(class_, key, None)
_init_class_state(class_)
+ class_._class_state.deferred_scalar_loader=deferred_scalar_loader
oldinit = None
doinit = False
"""
from sqlalchemy import util, logging, exceptions
from sqlalchemy.sql import expression
+from itertools import chain
class_mapper = None
__all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension',
return prev + (mapper.base_mapper, key)
else:
return (mapper.base_mapper, key)
-
+
+def serialize_path(path):
+ if path is None:
+ return None
+
+ return [
+ (mapper.class_, mapper.entity_name, key)
+ for mapper, key in [(path[i], path[i+1]) for i in range(0, len(path)-1, 2)]
+ ]
+
+def deserialize_path(path):
+ if path is None:
+ return None
+
+ global class_mapper
+ if class_mapper is None:
+ from sqlalchemy.orm import class_mapper
+
+ return tuple(
+ chain(*[(class_mapper(cls, entity), key) for cls, entity, key in path])
+ )
class MapperOption(object):
"""Describe a modification to a Query."""
def on_exception(class_, oldinit, instance, args, kwargs):
util.warn_exception(self.extension.init_failed, self, class_, oldinit, instance, args, kwargs)
- attributes.register_class(self.class_, extra_init=extra_init, on_exception=on_exception)
+ attributes.register_class(self.class_, extra_init=extra_init, on_exception=on_exception, deferred_scalar_loader=_load_scalar_attributes)
self._class_state = self.class_._class_state
_mapper_registry[self] = True
instance._sa_session_id = context.session.hash_key
session_identity_map[identitykey] = instance
- if currentload or context.populate_existing or self.always_refresh or state.trigger:
+ if currentload or context.populate_existing or self.always_refresh:
if isnew:
state.runid = context.runid
- state.trigger = None
context.progress.add(state)
-
+
if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
self.populate_instance(context, instance, row, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew)
-
+
+ elif getattr(state, 'expired_attributes', None):
+ if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=state.expired_attributes, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
+ self.populate_instance(context, instance, row, only_load_props=state.expired_attributes, instancekey=identitykey, isnew=isnew)
+
if result is not None and ('append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE):
result.append(instance)
return instance
-
- def _deferred_inheritance_condition(self, base_mapper, needs_tables):
- def visit_binary(binary):
- leftcol = binary.left
- rightcol = binary.right
- if leftcol is None or rightcol is None:
- return
- if leftcol.table not in needs_tables:
- binary.left = sql.bindparam(None, None, type_=binary.right.type)
- param_names.append((leftcol, binary.left))
- elif rightcol not in needs_tables:
- binary.right = sql.bindparam(None, None, type_=binary.right.type)
- param_names.append((rightcol, binary.right))
-
- allconds = []
- param_names = []
-
- for mapper in self.iterate_to_root():
- if mapper is base_mapper:
- break
- allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary))
-
- return sql.and_(*allconds), param_names
def translate_row(self, tomapper, row):
"""Translate the column keys of a row into a new or proxied
populators = new_populators
else:
populators = existing_populators
-
+
+ if only_load_props:
+ populators = [p for p in populators if p[0] in only_load_props]
+
for (key, populator) in populators:
selectcontext.exec_with_path(self, key, populator, instance, row, ispostselect=ispostselect, isnew=isnew, **flags)
p(state.obj())
def _get_poly_select_loader(self, selectcontext, row):
- # 'select' or 'union'+col not present
+ """set up attribute loaders for 'select' and 'deferred' polymorphic loading.
+
+ this loading uses a second SELECT statement to load additional tables,
+ either immediately after loading the main table or via a deferred attribute trigger.
+ """
+
(hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', self), (None, None))
- if hosted_mapper is None or not needs_tables or hosted_mapper.polymorphic_fetch == 'deferred':
+
+ if hosted_mapper is None or not needs_tables:
return
cond, param_names = self._deferred_inheritance_condition(hosted_mapper, needs_tables)
statement = sql.select(needs_tables, cond, use_labels=True)
- def post_execute(instance, **flags):
- if self.__should_log_debug:
- self.__log_debug("Post query loading instance " + instance_str(instance))
+
+ if hosted_mapper.polymorphic_fetch == 'select':
+ def post_execute(instance, **flags):
+ if self.__should_log_debug:
+ self.__log_debug("Post query loading instance " + instance_str(instance))
+
+ identitykey = self.identity_key_from_instance(instance)
+
+ params = {}
+ for c, bind in param_names:
+ params[bind] = self._get_attr_by_column(instance, c)
+ row = selectcontext.session.connection(self).execute(statement, params).fetchone()
+ self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True)
+ return post_execute
+ elif hosted_mapper.polymorphic_fetch == 'deferred':
+ from sqlalchemy.orm.strategies import DeferredColumnLoader
+
+ def post_execute(instance, **flags):
+ def create_statement(instance):
+ params = {}
+ for (c, bind) in param_names:
+ # use the "committed" (database) version to get query column values
+ params[bind] = self._get_committed_attr_by_column(instance, c)
+ return (statement, params)
+
+ props = [prop for prop in [self._get_col_to_prop(col) for col in statement.inner_columns] if prop.key not in instance.__dict__]
+ keys = [p.key for p in props]
+ for prop in props:
+ strategy = prop._get_strategy(DeferredColumnLoader)
+ instance._state.set_callable(prop.key, strategy.setup_loader(instance, props=keys, create_statement=create_statement))
+ return post_execute
+ else:
+ return None
+
+ def _deferred_inheritance_condition(self, base_mapper, needs_tables):
+ def visit_binary(binary):
+ leftcol = binary.left
+ rightcol = binary.right
+ if leftcol is None or rightcol is None:
+ return
+ if leftcol.table not in needs_tables:
+ binary.left = sql.bindparam(None, None, type_=binary.right.type)
+ param_names.append((leftcol, binary.left))
+ elif rightcol not in needs_tables:
+ binary.right = sql.bindparam(None, None, type_=binary.right.type)
+ param_names.append((rightcol, binary.right))
- identitykey = self.identity_key_from_instance(instance)
+ allconds = []
+ param_names = []
- params = {}
- for c, bind in param_names:
- params[bind] = self._get_attr_by_column(instance, c)
- row = selectcontext.session.connection(self).execute(statement, params).fetchone()
- self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True)
+ for mapper in self.iterate_to_root():
+ if mapper is base_mapper:
+ break
+ allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary))
- return post_execute
+ return sql.and_(*allconds), param_names
Mapper.logger = logging.class_logger(Mapper)
return hasattr(object, '_entity_name')
+object_session = None
+
+def _load_scalar_attributes(instance, attribute_names):
+ global object_session
+ if not object_session:
+ from sqlalchemy.orm.session import object_session
+
+ if object_session(instance).query(object_mapper(instance))._get(instance._instance_key, refresh_instance=instance._state, only_load_props=attribute_names) is None:
+ raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % instance_str(instance))
+
def _state_mapper(state, entity_name=None):
return state.class_._class_state.mappers[state.dict.get('_entity_name', entity_name)]
return util.IdentitySet(self.uow.new.values())
new = property(new)
-
+
def _expire_state(state, attribute_names):
"""Standalone expire instance function.
If the list is None or blank, the entire instance is expired.
"""
- if state.trigger is None:
- def load_attributes(instance, attribute_names):
- if object_session(instance).query(instance.__class__)._get(instance._instance_key, refresh_instance=instance._state, only_load_props=attribute_names) is None:
- raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance))
- state.trigger = load_attributes
-
state.expire_attributes(attribute_names)
register_attribute = unitofwork.register_attribute
from sqlalchemy.sql import util as sql_util
from sqlalchemy.sql import visitors, expression, operators
from sqlalchemy.orm import mapper, attributes
-from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption
+from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption, serialize_path, deserialize_path
from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm import util as mapperutil
if self._should_log_debug:
self.logger.debug("Returning active column fetcher for %s %s" % (mapper, self.key))
return (new_execute, None, None)
-
- # our mapped column is not present in the row. check if we need to initialize a polymorphic
- # row fetcher used by inheritance.
- (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', mapper), (None, None))
-
- if hosted_mapper is None:
- return (None, None, None)
-
- if hosted_mapper.polymorphic_fetch == 'deferred':
- # 'deferred' polymorphic row fetcher, put a callable on the property.
- # create a deferred column loader which will query the remaining not-yet-loaded tables in an inheritance load.
- # the mapper for the object creates the WHERE criterion using the mapper who originally
- # "hosted" the query and the list of tables which are unloaded between the "hosted" mapper
- # and this mapper. (i.e. A->B->C, the query used mapper A. therefore will need B's and C's tables
- # in the query).
-
- # deferred loader strategy
- strategy = self.parent_property._get_strategy(DeferredColumnLoader)
-
- # full list of ColumnProperty objects to be loaded in the deferred fetch
- props = [p.key for p in mapper.iterate_properties if isinstance(p.strategy, ColumnLoader) and p.columns[0].table in needs_tables]
-
- # TODO: we are somewhat duplicating efforts from mapper._get_poly_select_loader
- # and should look for ways to simplify.
- cond, param_names = mapper._deferred_inheritance_condition(hosted_mapper, needs_tables)
- statement = sql.select(needs_tables, cond, use_labels=True)
- def create_statement(instance):
- params = {}
- for (c, bind) in param_names:
- # use the "committed" (database) version to get query column values
- params[bind] = mapper._get_committed_attr_by_column(instance, c)
- return (statement, params)
-
+ else:
def new_execute(instance, row, isnew, **flags):
if isnew:
- instance._state.set_callable(self.key, strategy.setup_loader(instance, props=props, create_statement=create_statement))
-
+ instance._state.expire_attributes([self.key])
if self._should_log_debug:
- self.logger.debug("Returning deferred column fetcher for %s %s" % (mapper, self.key))
-
+ self.logger.debug("Deferring load for %s %s" % (mapper, self.key))
return (new_execute, None, None)
- else:
- # immediate polymorphic row fetcher. no processing needed for this row.
- if self._should_log_debug:
- self.logger.debug("Returning no column fetcher for %s %s" % (mapper, self.key))
- return (None, None, None)
-
ColumnLoader.logger = logging.class_logger(ColumnLoader)
self.parent_property._get_strategy(ColumnLoader).setup_query(context, **kwargs)
def setup_loader(self, instance, props=None, create_statement=None):
- localparent = mapper.object_mapper(instance, raiseerror=False)
- if localparent is None:
+ if not mapper.has_mapper(instance):
return None
+
+ localparent = mapper.object_mapper(instance)
# adjust for the ColumnProperty associated with the instance
# not being our own ColumnProperty. This can occur when entity_name
prop = localparent.get_property(self.key)
if prop is not self.parent_property:
return prop._get_strategy(DeferredColumnLoader).setup_loader(instance)
-
- def lazyload():
- if not mapper.has_identity(instance):
- return None
-
- if props is not None:
- group = props
- elif self.group is not None:
- group = [p.key for p in localparent.iterate_properties if isinstance(p.strategy, DeferredColumnLoader) and p.group==self.group]
- else:
- group = [self.parent_property.key]
-
- # narrow the keys down to just those which aren't present on the instance
- group = [k for k in group if k not in instance.__dict__]
-
- if self._should_log_debug:
- self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), group and ','.join(group) or 'None'))
-
- session = sessionlib.object_session(instance)
- if session is None:
- raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
- if create_statement is None:
- ident = instance._instance_key[1]
- session.query(localparent)._get(None, ident=ident, only_load_props=group, refresh_instance=instance._state)
- else:
- statement, params = create_statement(instance)
- session.query(localparent).from_statement(statement).params(params)._get(None, only_load_props=group, refresh_instance=instance._state)
- return attributes.ATTR_WAS_SET
- return lazyload
+ return LoadDeferredColumns(instance, self.key, props, optimizing_statement=create_statement)
DeferredColumnLoader.logger = logging.class_logger(DeferredColumnLoader)
+class LoadDeferredColumns(object):
+ """callable, serializable loader object used by DeferredColumnLoader"""
+
+ def __init__(self, instance, key, keys, optimizing_statement):
+ self.instance = instance
+ self.key = key
+ self.keys = keys
+ self.optimizing_statement = optimizing_statement
+
+ def __getstate__(self):
+ return {'instance':self.instance, 'key':self.key, 'keys':self.keys}
+
+ def __setstate__(self, state):
+ self.instance = state['instance']
+ self.key = state['key']
+ self.keys = state['keys']
+ self.optimizing_statement = None
+
+ def __call__(self):
+ if not mapper.has_identity(self.instance):
+ return None
+
+ localparent = mapper.object_mapper(self.instance, raiseerror=False)
+
+ prop = localparent.get_property(self.key)
+ strategy = prop._get_strategy(DeferredColumnLoader)
+
+ if self.keys:
+ toload = self.keys
+ elif strategy.group:
+ toload = [p.key for p in localparent.iterate_properties if isinstance(p.strategy, DeferredColumnLoader) and p.group==strategy.group]
+ else:
+ toload = [self.key]
+
+ # narrow the keys down to just those which have no history
+ group = [k for k in toload if k in self.instance._state.unmodified]
+
+ if strategy._should_log_debug:
+ strategy.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(self.instance, self.key), group and ','.join(group) or 'None'))
+
+ session = sessionlib.object_session(self.instance)
+ if session is None:
+ raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (self.instance.__class__, self.key))
+
+ query = session.query(localparent)
+ if not self.optimizing_statement:
+ ident = self.instance._instance_key[1]
+ query._get(None, ident=ident, only_load_props=group, refresh_instance=self.instance._state)
+ else:
+ statement, params = self.optimizing_statement(self.instance)
+ query.from_statement(statement).params(params)._get(None, only_load_props=group, refresh_instance=self.instance._state)
+ return attributes.ATTR_WAS_SET
+
class DeferredOption(StrategizedOption):
def __init__(self, key, defer=False):
super(DeferredOption, self).__init__(key)
class LazyLoader(AbstractRelationLoader):
def init(self):
super(LazyLoader, self).init()
- (self.lazywhere, self.lazybinds, self.lazyreverse) = self._create_lazy_clause(self)
+ (self.lazywhere, self.lazybinds, self.equated_columns) = self._create_lazy_clause(self)
self.logger.info(str(self.parent_property) + " lazy loading clause " + str(self.lazywhere))
def lazy_clause(self, instance, reverse_direction=False):
if instance is None:
- return self.lazy_none_clause(reverse_direction)
+ return self._lazy_none_clause(reverse_direction)
if not reverse_direction:
- (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.lazyreverse)
+ (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.equated_columns)
else:
(criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
bindparam.value = mapper._get_committed_attr_by_column(instance, bind_to_col[bindparam.key])
return visitors.traverse(criterion, clone=True, visit_bindparam=visit_bindparam)
- def lazy_none_clause(self, reverse_direction=False):
+ def _lazy_none_clause(self, reverse_direction=False):
if not reverse_direction:
- (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.lazyreverse)
+ (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.equated_columns)
else:
(criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
def setup_loader(self, instance, options=None, path=None):
if not mapper.has_mapper(instance):
return None
- else:
- # adjust for the PropertyLoader associated with the instance
- # not being our own PropertyLoader. This can occur when entity_name
- # mappers are used to map different versions of the same PropertyLoader
- # to the class.
- prop = mapper.object_mapper(instance).get_property(self.key)
- if prop is not self.parent_property:
- return prop._get_strategy(LazyLoader).setup_loader(instance)
-
- def lazyload():
- if self._should_log_debug:
- self.logger.debug("lazy load attribute %s on instance %s" % (self.key, mapperutil.instance_str(instance)))
- if not mapper.has_identity(instance):
- return None
+ localparent = mapper.object_mapper(instance)
- session = sessionlib.object_session(instance)
- if session is None:
- try:
- session = mapper.object_mapper(instance).get_session()
- except exceptions.InvalidRequestError:
- raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
-
- # if we have a simple straight-primary key load, use mapper.get()
- # to possibly save a DB round trip
- q = session.query(self.mapper).autoflush(False)
- if path:
- q = q._with_current_path(path)
- if self.use_get:
- params = {}
- for col, bind in self.lazybinds.iteritems():
- # use the "committed" (database) version to get query column values
- params[bind.key] = self.parent._get_committed_attr_by_column(instance, col)
- ident = []
- nonnulls = False
- for primary_key in self.select_mapper.primary_key:
- bind = self.lazyreverse[primary_key]
- v = params[bind.key]
- if v is not None:
- nonnulls = True
- ident.append(v)
- if not nonnulls:
- return None
- if options:
- q = q._conditional_options(*options)
- return q.get(ident)
- elif self.order_by is not False:
- q = q.order_by(self.order_by)
- elif self.secondary is not None and self.secondary.default_order_by() is not None:
- q = q.order_by(self.secondary.default_order_by())
-
- if options:
- q = q._conditional_options(*options)
- q = q.filter(self.lazy_clause(instance))
-
- result = q.all()
- if self.uselist:
- return result
- else:
- if result:
- return result[0]
- else:
- return None
-
- return lazyload
+ # adjust for the PropertyLoader associated with the instance
+ # not being our own PropertyLoader. This can occur when entity_name
+ # mappers are used to map different versions of the same PropertyLoader
+ # to the class.
+ prop = localparent.get_property(self.key)
+ if prop is not self.parent_property:
+ return prop._get_strategy(LazyLoader).setup_loader(instance)
+
+ return LoadLazyAttribute(instance, self.key, options, path)
def create_row_processor(self, selectcontext, mapper, row):
if not self.is_class_level or len(selectcontext.options):
(primaryjoin, secondaryjoin, remote_side) = (prop.polymorphic_primaryjoin, prop.polymorphic_secondaryjoin, prop.remote_side)
binds = {}
- reverse = {}
+ equated_columns = {}
def should_bind(targetcol, othercol):
if reverse_direction and not secondaryjoin:
return
leftcol = binary.left
rightcol = binary.right
-
+
+ equated_columns[rightcol] = leftcol
+ equated_columns[leftcol] = rightcol
+
if should_bind(leftcol, rightcol):
- col = leftcol
- binary.left = binds.setdefault(leftcol,
- sql.bindparam(None, None, type_=binary.right.type))
- reverse[rightcol] = binds[col]
+ binary.left = binds[leftcol] = sql.bindparam(None, None, type_=binary.right.type)
# the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1",
# which can happen in rare cases (test/orm/relationships.py RelationTest2)
if leftcol is not rightcol and should_bind(rightcol, leftcol):
- col = rightcol
- binary.right = binds.setdefault(rightcol,
- sql.bindparam(None, None, type_=binary.left.type))
- reverse[leftcol] = binds[col]
+ binary.right = binds[rightcol] = sql.bindparam(None, None, type_=binary.left.type)
lazywhere = primaryjoin
if reverse_direction:
secondaryjoin = visitors.traverse(secondaryjoin, clone=True, visit_binary=visit_binary)
lazywhere = sql.and_(lazywhere, secondaryjoin)
- return (lazywhere, binds, reverse)
+ return (lazywhere, binds, equated_columns)
_create_lazy_clause = classmethod(_create_lazy_clause)
LazyLoader.logger = logging.class_logger(LazyLoader)
+class LoadLazyAttribute(object):
+ """callable, serializable loader object used by LazyLoader"""
+
+ def __init__(self, instance, key, options, path):
+ self.instance = instance
+ self.key = key
+ self.options = options
+ self.path = path
+
+ def __getstate__(self):
+ return {'instance':self.instance, 'key':self.key, 'options':self.options, 'path':serialize_path(self.path)}
+
+ def __setstate__(self, state):
+ self.instance = state['instance']
+ self.key = state['key']
+ self.options= state['options']
+ self.path = deserialize_path(state['path'])
+
+ def __call__(self):
+ instance = self.instance
+
+ if not mapper.has_identity(instance):
+ return None
+
+ instance_mapper = mapper.object_mapper(instance)
+ prop = instance_mapper.get_property(self.key)
+ strategy = prop._get_strategy(LazyLoader)
+
+ if strategy._should_log_debug:
+ strategy.logger.debug("lazy load attribute %s on instance %s" % (self.key, mapperutil.instance_str(instance)))
+
+ session = sessionlib.object_session(instance)
+ if session is None:
+ try:
+ session = instance_mapper.get_session()
+ except exceptions.InvalidRequestError:
+ raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
+
+ q = session.query(prop.mapper).autoflush(False)
+ if self.path:
+ q = q._with_current_path(self.path)
+
+ # if we have a simple primary key load, use mapper.get()
+ # to possibly save a DB round trip
+ if strategy.use_get:
+ ident = []
+ allnulls = True
+ for primary_key in prop.select_mapper.primary_key:
+ val = instance_mapper._get_committed_attr_by_column(instance, strategy.equated_columns[primary_key])
+ allnulls = allnulls and val is None
+ ident.append(val)
+ if allnulls:
+ return None
+ if self.options:
+ q = q._conditional_options(*self.options)
+ return q.get(ident)
+
+ if strategy.order_by is not False:
+ q = q.order_by(strategy.order_by)
+ elif strategy.secondary is not None and strategy.secondary.default_order_by() is not None:
+ q = q.order_by(strategy.secondary.default_order_by())
+
+ if self.options:
+ q = q._conditional_options(*self.options)
+ q = q.filter(strategy.lazy_clause(instance))
+
+ result = q.all()
+ if strategy.uselist:
+ return result
+ else:
+ if result:
+ return result[0]
+ else:
+ return None
+
class EagerLoader(AbstractRelationLoader):
"""Loads related objects inline with a parent query."""
if self._should_log_debug:
self.logger.debug("eager loader %s degrading to lazy loader" % str(self))
return self.parent_property._get_strategy(LazyLoader).create_row_processor(selectcontext, mapper, row)
-
-
+
def __str__(self):
return str(self.parent) + "." + self.key
def state_str(state):
"""Return a string describing an instance."""
-
- return state.class_.__name__ + "@" + hex(id(state.obj()))
+ if state is None:
+ return "None"
+ else:
+ return state.class_.__name__ + "@" + hex(id(state.obj()))
def attribute_str(instance, attribute):
return instance_str(instance) + "." + attribute
'subquery', 'table', 'text', 'union', 'union_all', 'update', ]
-BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
def desc(column):
"""Return a descending ``ORDER BY`` clause element.
__visit_name__ = 'textclause'
+ _bind_params_regex = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
+
def __init__(self, text = "", bind=None, bindparams=None, typemap=None):
self._bind = bind
self.bindparams = {}
# scan the string and search for bind parameter names, add them
# to the list of bindparams
- self.text = BIND_PARAMS.sub(repl, text)
+ self.text = self._bind_params_regex.sub(repl, text)
if bindparams is not None:
for b in bindparams:
self.bindparams[b.key] = b
'orm.relationships',
'orm.association',
'orm.merge',
+ 'orm.pickled',
'orm.memusage',
'orm.cycles',
self.assert_(o4.mt2[0].a == 'abcde')
self.assert_(o4.mt2[0].b is None)
+ def test_deferred(self):
+ class Foo(object):pass
+
+ data = {'a':'this is a', 'b':12}
+ def loader(instance, keys):
+ for k in keys:
+ instance.__dict__[k] = data[k]
+ return attributes.ATTR_WAS_SET
+
+ attributes.register_class(Foo, deferred_scalar_loader=loader)
+ attributes.register_attribute(Foo, 'a', uselist=False, useobject=False)
+ attributes.register_attribute(Foo, 'b', uselist=False, useobject=False)
+
+ f = Foo()
+ f._state.expire_attributes(None)
+ self.assertEquals(f.a, "this is a")
+ self.assertEquals(f.b, 12)
+
+ f.a = "this is some new a"
+ f._state.expire_attributes(None)
+ self.assertEquals(f.a, "this is a")
+ self.assertEquals(f.b, 12)
+
+ f._state.expire_attributes(None)
+ f.a = "this is another new a"
+ self.assertEquals(f.a, "this is another new a")
+ self.assertEquals(f.b, 12)
+
+ f._state.expire_attributes(None)
+ self.assertEquals(f.a, "this is a")
+ self.assertEquals(f.b, 12)
+
+ del f.a
+ self.assertEquals(f.a, None)
+ self.assertEquals(f.b, 12)
+
+ f._state.commit_all()
+ self.assertEquals(f.a, None)
+ self.assertEquals(f.b, 12)
+
+ def test_deferred_pickleable(self):
+ data = {'a':'this is a', 'b':12}
+ def loader(instance, keys):
+ for k in keys:
+ instance.__dict__[k] = data[k]
+ return attributes.ATTR_WAS_SET
+
+ attributes.register_class(MyTest, deferred_scalar_loader=loader)
+ attributes.register_attribute(MyTest, 'a', uselist=False, useobject=False)
+ attributes.register_attribute(MyTest, 'b', uselist=False, useobject=False)
+
+ m = MyTest()
+ m._state.expire_attributes(None)
+ assert 'a' not in m.__dict__
+ m2 = pickle.loads(pickle.dumps(m))
+ assert 'a' not in m2.__dict__
+ self.assertEquals(m2.a, "this is a")
+ self.assertEquals(m2.b, 12)
+
def test_list(self):
class User(object):pass
class Address(object):pass
self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [], []))
lazy_load = [bar1, bar2, bar3]
- f._state.trigger = lazyload(f)
f._state.expire_attributes(['bars'])
self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar2, bar3], []))
self.assert_sql_count(testbase.db, go, 1)
assert 'name' in u.__dict__
- # we're changing the database here, so if this test fails in the middle,
- # it'll screw up the other tests which are hardcoded to 7/'jack'
u.name = 'foo'
sess.flush()
# change the value in the DB
# test that it refreshed
assert u.__dict__['name'] == 'jack'
- # object should be back to normal now,
- # this should *not* produce a SELECT statement (not tested here though....)
- assert u.name == 'jack'
+ def go():
+ assert u.name == 'jack'
+ self.assert_sql_count(testbase.db, go, 0)
def test_expire_doesntload_on_set(self):
mapper(User, users)
assert o.isopen == 1
self.assert_sql_count(testbase.db, go, 1)
assert o.description == 'order 3 modified'
+
+ del o.description
+ assert "description" not in o.__dict__
+ sess.expire(o, ['isopen'])
+ sess.query(Order).all()
+ assert o.isopen == 1
+ assert "description" not in o.__dict__
+
+ assert o.description is None
def test_expire_committed(self):
"""test that the committed state of the attribute receives the most recent DB data"""
def go():
assert u.addresses[0].email_address == 'jack@bean.com'
assert u.name == 'jack'
- # one load
- self.assert_sql_count(testbase.db, go, 1)
+ # two loads, since relation() + scalar are
+ # separate right now
+ self.assert_sql_count(testbase.db, go, 2)
assert 'name' in u.__dict__
assert 'addresses' in u.__dict__
+ sess.expire(u, ['name', 'addresses'])
+ assert 'name' not in u.__dict__
+ assert 'addresses' not in u.__dict__
+
def test_partial_expire(self):
mapper(Order, orders)
s.expire(u)
# get the attribute, it refreshes
+ print "OK------"
+# print u.__dict__
+# print u._state.callables
assert u.name == 'jack'
assert id(a) not in [id(x) for x in u.addresses]
a = s.query(Address).from_statement(select([addresses.c.address_id, addresses.c.user_id])).first()
assert a.user_id == 7
assert a.address_id == 1
- assert a.email_address is None
+ # email address auto-defers
+ assert 'email_addres' not in a.__dict__
+ assert a.email_address == 'jack@bean.com'
def test_badconstructor(self):
"""test that if the construction of a mapped class fails, the instnace does not get placed in the session"""
--- /dev/null
+import testbase
+from sqlalchemy import *
+from sqlalchemy import exceptions
+from sqlalchemy.orm import *
+from testlib import *
+from testlib.fixtures import *
+import pickle
+
+class EmailUser(User):
+ pass
+
+class PickleTest(FixtureTest):
+ keep_mappers = False
+ keep_data = False
+
+ def test_transient(self):
+ mapper(User, users, properties={
+ 'addresses':relation(Address, backref="user")
+ })
+ mapper(Address, addresses)
+
+ sess = create_session()
+ u1 = User(name='ed')
+ u1.addresses.append(Address(email_address='ed@bar.com'))
+
+ u2 = pickle.loads(pickle.dumps(u1))
+ sess.save(u2)
+ sess.flush()
+
+ sess.clear()
+
+ self.assertEquals(u1, sess.query(User).get(u2.id))
+
+ def test_class_deferred_cols(self):
+ mapper(User, users, properties={
+ 'name':deferred(users.c.name),
+ 'addresses':relation(Address, backref="user")
+ })
+ mapper(Address, addresses, properties={
+ 'email_address':deferred(addresses.c.email_address)
+ })
+ sess = create_session()
+ u1 = User(name='ed')
+ u1.addresses.append(Address(email_address='ed@bar.com'))
+ sess.save(u1)
+ sess.flush()
+ sess.clear()
+ u1 = sess.query(User).get(u1.id)
+ assert 'name' not in u1.__dict__
+ assert 'addresses' not in u1.__dict__
+
+ u2 = pickle.loads(pickle.dumps(u1))
+ sess2 = create_session()
+ sess2.update(u2)
+ self.assertEquals(u2.name, 'ed')
+ self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
+
+ def test_instance_deferred_cols(self):
+ mapper(User, users, properties={
+ 'addresses':relation(Address, backref="user")
+ })
+ mapper(Address, addresses)
+
+ sess = create_session()
+ u1 = User(name='ed')
+ u1.addresses.append(Address(email_address='ed@bar.com'))
+ sess.save(u1)
+ sess.flush()
+ sess.clear()
+
+ u1 = sess.query(User).options(defer('name'), defer('addresses.email_address')).get(u1.id)
+ assert 'name' not in u1.__dict__
+ assert 'addresses' not in u1.__dict__
+
+ u2 = pickle.loads(pickle.dumps(u1))
+ sess2 = create_session()
+ sess2.update(u2)
+ self.assertEquals(u2.name, 'ed')
+ assert 'addresses' not in u1.__dict__
+ ad = u2.addresses[0]
+ assert 'email_address' not in ad.__dict__
+ self.assertEquals(ad.email_address, 'ed@bar.com')
+ self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
+
+class PolymorphicDeferredTest(ORMTest):
+ def define_tables(self, metadata):
+ global users, email_users
+ users = Table('users', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', String(30)),
+ Column('type', String(30)),
+ )
+ email_users = Table('email_users', metadata,
+ Column('id', Integer, ForeignKey('users.id'), primary_key=True),
+ Column('email_address', String(30))
+ )
+
+ def test_polymorphic_deferred(self):
+ mapper(User, users, polymorphic_identity='user', polymorphic_on=users.c.type, polymorphic_fetch='deferred')
+ mapper(EmailUser, email_users, inherits=User, polymorphic_identity='emailuser')
+
+ eu = EmailUser(name="user1", email_address='foo@bar.com')
+ sess = create_session()
+ sess.save(eu)
+ sess.flush()
+ sess.clear()
+
+ eu = sess.query(User).first()
+ eu2 = pickle.loads(pickle.dumps(eu))
+ sess2 = create_session()
+ sess2.update(eu2)
+ assert 'email_address' not in eu2.__dict__
+ self.assertEquals(eu2.email_address, 'foo@bar.com')
+
+
+
+
+if __name__ == '__main__':
+ testbase.main()
\ No newline at end of file