all MapperOptions now process the Query directly and that's it; one very simplified QueryContext object gets passed
around at query.compile() and query.instances() time (a toy sketch of the flow follows these notes)
- slight optimization to MapperExtension, allowing the mapper to check for the presence of an overridden method before dispatching; takes 3000 calls off of the masseagerload.py test (only a slight increase in speed, though). see the sketch below
- attempting to centralize the notion of a "path" along mappers/properties via the new build_path() helper (see the sketch below); what a "path" is still needs a better definition. heading towards [ticket:777]...
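
For illustration only, a minimal standalone sketch of the new flow. The Toy* names below are stand-ins, not the real Query/QueryContext classes, but the pattern mirrors the patch: options write into the Query's _attributes dict when applied, and a single simplified context snapshots that dict at compile()/instances() time.

    # toy sketch, assuming nothing beyond the pattern shown in the patch
    class ToyQuery(object):
        def __init__(self):
            self._attributes = {}        # MapperOptions write into this dict
            self._with_options = []

        def options(self, *opts):
            for opt in opts:
                self._with_options.append(opt)
                opt.process_query(self)  # options now act on the Query directly
            return self

    class ToyQueryContext(object):
        # built fresh at compile()/instances() time from the Query's current state
        def __init__(self, query):
            self.query = query
            self.attributes = query._attributes.copy()

    class ToyUndeferGroupOption(object):
        def __init__(self, group):
            self.group = group
        def process_query(self, query):
            query._attributes[('undefer', self.group)] = True

    q = ToyQuery().options(ToyUndeferGroupOption('photos'))
    ctx = ToyQueryContext(q)
    assert ctx.attributes[('undefer', 'photos')] is True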
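
The extension optimization, again as a hedged standalone sketch (ToyExtension/ToyCarrier/AuditExt are invented names, and the override test here just checks the subclass __dict__, which is a simplification of what the patch does): a 'methods' dict records only the hooks an extension actually overrides, so hot paths can do a cheap dict-membership check instead of dispatching through a chain of no-op defaults.

    EXT_CONTINUE = object()

    class ToyExtension(object):
        # default hooks: no-ops that just continue
        def before_insert(self, obj):
            return EXT_CONTINUE
        def after_insert(self, obj):
            return EXT_CONTINUE

    class ToyCarrier(object):
        def __init__(self):
            self._elements = []
            self.methods = {}            # only overridden hooks end up in here
        def append(self, ext):
            for name in ('before_insert', 'after_insert'):
                # simplified override test: is the hook defined on the subclass itself?
                if name not in self.methods and name in type(ext).__dict__:
                    self.methods[name] = getattr(ext, name)
            self._elements.append(ext)

    class AuditExt(ToyExtension):
        # overrides only before_insert; after_insert stays the no-op default
        def before_insert(self, obj):
            self.seen = obj
            return EXT_CONTINUE

    carrier = ToyCarrier()
    carrier.append(AuditExt())
    assert 'before_insert' in carrier.methods     # overridden, so registered
    assert 'after_insert' not in carrier.methods  # default, so hot paths skip the call entirely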
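
build_path itself is small; a purely illustrative, runnable version with strings standing in for mappers (the real function in the patch appends mapper.base_mapper rather than the mapper itself):

    def build_path(mapper, key, prev=None):
        # a "path" is a flat tuple of alternating (mapper, property key) entries
        if prev:
            return prev + (mapper, key)
        return (mapper, key)

    p = build_path('User', 'addresses')
    p = build_path('Address', 'email_address', p)
    assert p == ('User', 'addresses', 'Address', 'email_address')
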
__all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension',
'MapperProperty', 'PropComparator', 'StrategizedProperty',
- 'LoaderStack', 'OperationContext', 'MapperOption',
+ 'LoaderStack', 'build_path', 'MapperOption',
'ExtensionOption', 'SynonymProperty', 'PropertyOption',
'AttributeExtension', 'StrategizedOption', 'LoaderStrategy' ]
if self.is_primary():
self.strategy.init_class_attribute()
+def build_path(mapper, key, prev=None):
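+ # a "path" is a flat tuple of alternating (base mapper, property key) entries; extend 'prev' if given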
+ if prev:
+ return prev + (mapper.base_mapper, key)
+ else:
+ return (mapper.base_mapper, key)
+
class LoaderStack(object):
"""a stack object used during load operations to track the
current position among a chain of mappers and eager loaders."""
def __str__(self):
return "->".join([str(s) for s in self.__stack])
-
-class OperationContext(object):
- """Serve as a context during a query construction or instance
- loading operation.
- Accept ``MapperOption`` objects which may modify its state before proceeding.
- """
-
- def __init__(self, mapper, options, attributes=None):
- self.mapper = mapper
- self.options = options
- self.attributes = attributes or {}
- self.recursion_stack = util.Set()
- for opt in util.flatten_iterator(options):
- self.accept_option(opt)
-
- def accept_option(self, opt):
- pass
class MapperOption(object):
- """Describe a modification to an OperationContext or Query."""
-
- def process_query_context(self, context):
- pass
-
- def process_selection_context(self, context):
- pass
+ """Describe a modification to a Query."""
def process_query(self, query):
pass
def __init__(self, key):
self.key = key
- def process_query_property(self, context, properties):
- pass
+ def process_query(self, query):
+ self.process_query_property(query, self._get_properties(query))
- def process_selection_property(self, context, properties):
+ def process_query_property(self, query, properties):
pass
- def process_query_context(self, context):
- self.process_query_property(context, self._get_properties(context))
-
- def process_selection_context(self, context):
- self.process_selection_property(context, self._get_properties(context))
-
- def _get_properties(self, context):
+ def _get_properties(self, query):
try:
l = self.__prop
except AttributeError:
l = []
- mapper = context.mapper
+ mapper = query.mapper
for token in self.key.split('.'):
prop = mapper.get_property(token, resolve_synonyms=True)
l.append(prop)
def is_chained(self):
return False
- def process_query_property(self, context, properties):
- self.logger.debug("applying option to QueryContext, property key '%s'" % self.key)
+ def process_query_property(self, query, properties):
+ self.logger.debug("applying option to Query, property key '%s'" % self.key)
if self.is_chained():
for prop in properties:
- context.attributes[("loaderstrategy", prop)] = self.get_strategy_class()
+ query._attributes[("loaderstrategy", prop)] = self.get_strategy_class()
else:
- context.attributes[("loaderstrategy", properties[-1])] = self.get_strategy_class()
-
- def process_selection_property(self, context, properties):
- self.logger.debug("applying option to SelectionContext, property key '%s'" % self.key)
- if self.is_chained():
- for prop in properties:
- context.attributes[("loaderstrategy", prop)] = self.get_strategy_class()
- else:
- context.attributes[("loaderstrategy", properties[-1])] = self.get_strategy_class()
+ query._attributes[("loaderstrategy", properties[-1])] = self.get_strategy_class()
def get_strategy_class(self):
raise NotImplementedError()
def extra_init(class_, oldinit, instance, args, kwargs):
self.compile()
- self.extension.init_instance(self, class_, oldinit, instance, args, kwargs)
+ if 'init_instance' in self.extension.methods:
+ self.extension.init_instance(self, class_, oldinit, instance, args, kwargs)
def on_exception(class_, oldinit, instance, args, kwargs):
util.warn_exception(self.extension.init_failed, self, class_, oldinit, instance, args, kwargs)
Raise ``InvalidRequestError`` if a session cannot be retrieved
from the extension chain.
"""
+
+ if 'get_session' in self.extension.methods:
+ s = self.extension.get_session()
+ if s is not EXT_CONTINUE:
+ return s
- s = self.extension.get_session()
- if s is EXT_CONTINUE:
- raise exceptions.InvalidRequestError("No contextual Session is established. Use a MapperExtension that implements get_session or use 'import sqlalchemy.mods.threadlocal' to establish a default thread-local contextual session.")
- return s
-
+ raise exceptions.InvalidRequestError("No contextual Session is established. Use a MapperExtension that implements get_session or use 'import sqlalchemy.mods.threadlocal' to establish a default thread-local contextual session.")
+
def has_eager(self):
"""Return True if one of the properties attached to this
Mapper is eager loading.
for obj, connection in tups:
if not has_identity(obj):
for mapper in object_mapper(obj).iterate_to_root():
- mapper.extension.before_insert(mapper, connection, obj)
+ if 'before_insert' in mapper.extension.methods:
+ mapper.extension.before_insert(mapper, connection, obj)
else:
for mapper in object_mapper(obj).iterate_to_root():
- mapper.extension.before_update(mapper, connection, obj)
+ if 'before_update' in mapper.extension.methods:
+ mapper.extension.before_update(mapper, connection, obj)
for obj, connection in tups:
# detect if we have a "pending" instance (i.e. has no instance_key attached to it),
if not postupdate:
for obj, connection in inserted_objects:
for mapper in object_mapper(obj).iterate_to_root():
- mapper.extension.after_insert(mapper, connection, obj)
+ if 'after_insert' in mapper.extension.methods:
+ mapper.extension.after_insert(mapper, connection, obj)
for obj, connection in updated_objects:
for mapper in object_mapper(obj).iterate_to_root():
- mapper.extension.after_update(mapper, connection, obj)
+ if 'after_update' in mapper.extension.methods:
+ mapper.extension.after_update(mapper, connection, obj)
def _postfetch(self, connection, table, obj, resultproxy, params, value_params):
"""After an ``INSERT`` or ``UPDATE``, assemble newly generated
for (obj, connection) in tups:
for mapper in object_mapper(obj).iterate_to_root():
- mapper.extension.before_delete(mapper, connection, obj)
+ if 'before_delete' in mapper.extension.methods:
+ mapper.extension.before_delete(mapper, connection, obj)
deleted_objects = util.Set()
table_to_mapper = {}
for obj, connection in deleted_objects:
for mapper in object_mapper(obj).iterate_to_root():
- mapper.extension.after_delete(mapper, connection, obj)
+ if 'after_delete' in mapper.extension.methods:
+ mapper.extension.after_delete(mapper, connection, obj)
def _has_pks(self, table):
try:
else:
extension = self.extension
- ret = extension.translate_row(self, context, row)
- if ret is not EXT_CONTINUE:
- row = ret
+ if 'translate_row' in extension.methods:
+ ret = extension.translate_row(self, context, row)
+ if ret is not EXT_CONTINUE:
+ row = ret
if not skip_polymorphic and self.polymorphic_on is not None:
discriminator = row[self.polymorphic_on]
if identitykey not in local_identity_map:
local_identity_map[identitykey] = instance
isnew = True
- if extension.populate_instance(self, context, row, instance, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
+ if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
self.populate_instance(context, instance, row, instancekey=identitykey, isnew=isnew)
- if extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
+ if 'append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
if result is not None:
result.append(instance)
return instance
return None
# plugin point
- instance = extension.create_instance(self, context, row, self.class_)
- if instance is EXT_CONTINUE:
+ if 'create_instance' in extension.methods:
+ instance = extension.create_instance(self, context, row, self.class_)
+ if instance is EXT_CONTINUE:
+ instance = attribute_manager.new_instance(self.class_)
+ else:
instance = attribute_manager.new_instance(self.class_)
+
instance._entity_name = self.entity_name
if self.__should_log_debug:
self.__log_debug("_instance(): created new instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey)))
# call further mapper properties on the row, to pull further
# instances from the row and possibly populate this item.
flags = {'instancekey':identitykey, 'isnew':isnew}
- if extension.populate_instance(self, context, row, instance, **flags) is EXT_CONTINUE:
+ if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, **flags) is EXT_CONTINUE:
self.populate_instance(context, instance, row, **flags)
- if extension.append_result(self, context, row, instance, result, **flags) is EXT_CONTINUE:
+ if 'append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, **flags) is EXT_CONTINUE:
if result is not None:
result.append(instance)
from sqlalchemy.sql import expression, visitors
from sqlalchemy.orm import mapper, object_mapper
from sqlalchemy.orm import util as mapperutil
-from sqlalchemy.orm.interfaces import OperationContext, LoaderStack
+from sqlalchemy.orm.interfaces import LoaderStack
import operator
-__all__ = ['Query', 'QueryContext', 'SelectionContext']
+__all__ = ['Query', 'QueryContext']
class Query(object):
"""Encapsulates the object-fetching operations provided by Mappers."""
self._populate_existing = False
self._version_check = False
self._autoflush = True
+ self._eager_loaders = util.Set([x for x in self.mapper._eager_loaders])
+ self._attributes = {}
def _clone(self):
q = Query.__new__(Query)
"""
q = self._clone()
+ # most MapperOptions write to the '_attributes' dictionary,
+ # so copy that as well
+ q._attributes = q._attributes.copy()
opts = [o for o in util.flatten_iterator(args)]
q._with_options = q._with_options + opts
for opt in opts:
session = self.session
- kwargs.setdefault('populate_existing', self._populate_existing)
- kwargs.setdefault('version_check', self._version_check)
-
- context = SelectionContext(self.select_mapper, session, self._extension, with_options=self._with_options, **kwargs)
+ context = kwargs.pop('querycontext', None)
+ if context is None:
+ context = QueryContext(self)
process = []
mappers_or_columns = tuple(self._entities) + mappers_or_columns
except IndexError:
return None
- def _should_nest(self, querycontext):
- """Return True if the given statement options indicate that we
- should *nest* the generated query as a subquery inside of a
- larger eager-loading query. This is used with keywords like
- distinct, limit and offset and the mapper defines eager loads.
- """
-
- return (
- len(querycontext.eager_loaders) > 0
- and self._nestable(**querycontext.select_args())
- )
-
def _nestable(self, **kwargs):
"""Return true if the given statement options imply it should be nested."""
whereclause = self._criterion
context = QueryContext(self)
- from_obj = context.from_obj
+ from_obj = self._from_obj
alltables = []
for l in [sql_util.TableFinder(x) for x in from_obj]:
if self.table not in alltables:
from_obj.append(self.table)
- if self._nestable(**context.select_args()):
- s = sql.select([self.table], whereclause, from_obj=from_obj, **context.select_args()).alias('getcount').count()
+ if self._nestable(**self._select_args()):
+ s = sql.select([self.table], whereclause, from_obj=from_obj, **self._select_args()).alias('getcount').count()
else:
primary_key = self.primary_key_columns
- s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **context.select_args())
+ s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **self._select_args())
if self._autoflush and not self._populate_existing:
self.session._autoflush()
return self.session.scalar(s, params=self._params, mapper=self.mapper)
if isinstance(m, mapper.Mapper):
table = m.select_table
sql_util.ClauseAdapter(m.select_table).traverse(whereclause, stop_on=util.Set([m.select_table]))
+
+ from_obj = self._from_obj
- # get/create query context. get the ultimate compile arguments
- # from there
- order_by = context.order_by
- from_obj = context.from_obj
- lockmode = context.lockmode
+ order_by = self._order_by
if order_by is False:
order_by = self.mapper.order_by
if order_by is False:
order_by = self.table.default_order_by()
try:
- for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[lockmode]
+ for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[self._lockmode]
except KeyError:
- raise exceptions.ArgumentError("Unknown lockmode '%s'" % lockmode)
+ raise exceptions.ArgumentError("Unknown lockmode '%s'" % self._lockmode)
# if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so
# that we only load the appropriate types
if self.table not in alltables:
from_obj.append(self.table)
- if self._should_nest(context):
+ if self._eager_loaders and self._nestable(**self._select_args()):
# if theres an order by, add those columns to the column list
# of the "rowcount" query we're going to make
if order_by:
else:
cf = []
- s2 = sql.select(self.primary_key_columns + list(cf), whereclause, use_labels=True, from_obj=from_obj, correlate=False, **context.select_args())
+ s2 = sql.select(self.primary_key_columns + list(cf), whereclause, use_labels=True, from_obj=from_obj, correlate=False, **self._select_args())
if order_by:
s2 = s2.order_by(*util.to_list(order_by))
s3 = s2.alias('tbl_row_count')
if order_by:
statement.append_order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by))
else:
- statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **context.select_args())
+ statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **self._select_args())
if order_by:
statement.append_order_by(*util.to_list(order_by))
# to use it in "order_by". ensure they are in the column criterion (particularly oid).
# TODO: this should be done at the SQL level not the mapper level
# TODO: need test coverage for this
- if context.distinct and order_by:
+ if self._distinct and order_by:
[statement.append_column(c) for c in util.to_list(order_by)]
context.statement = statement
return context
+ def _select_args(self):
+ """Return a dictionary of attributes that can be applied to a ``sql.Select`` statement.
+ """
+ return {'limit':self._limit, 'offset':self._offset, 'distinct':self._distinct, 'group_by':self._group_by or None}
+
+
def _get_entity_clauses(self, m):
"""for tuples added via add_entity() or add_column(), attempt to locate
an AliasedClauses object which should be used to formulate the query as well
as to process result rows."""
+
(m, alias, alias_id) = m
if alias is not None:
return alias
Query.logger = logging.class_logger(Query)
-class QueryContext(OperationContext):
- """Created within the ``Query.compile()`` method to store and
- share state among all the Mappers and MapperProperty objects used
- in a query construction.
- """
-
+class QueryContext(object):
def __init__(self, query):
self.query = query
- self.order_by = query._order_by
- self.group_by = query._group_by
- self.from_obj = query._from_obj
- self.lockmode = query._lockmode
- self.distinct = query._distinct
- self.limit = query._limit
- self.offset = query._offset
- self.eager_loaders = util.Set([x for x in query.mapper._eager_loaders])
+ self.mapper = query.mapper
+ self.session = query.session
+ self.extension = query._extension
self.statement = None
- super(QueryContext, self).__init__(query.mapper, query._with_options)
-
- def select_args(self):
- """Return a dictionary of attributes from this
- ``QueryContext`` that can be applied to a ``sql.Select``
- statement.
- """
- return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct, 'group_by':self.group_by or None}
-
- def accept_option(self, opt):
- """Accept a ``MapperOption`` which will process (modify) the
- state of this ``QueryContext``.
- """
-
- opt.process_query_context(self)
-
-
-class SelectionContext(OperationContext):
- """Created within the ``query.instances()`` method to store and share
- state among all the Mappers and MapperProperty objects used in a
- load operation.
-
- SelectionContext contains these attributes:
-
- mapper
- The Mapper which originated the instances() call.
-
- session
- The Session that is relevant to the instances call.
-
- identity_map
- A dictionary which stores newly created instances that have not
- yet been added as persistent to the Session.
-
- attributes
- A dictionary to store arbitrary data; mappers, strategies, and
- options all store various state information here in order
- to communicate with each other and to themselves.
-
-
- populate_existing
- Indicates if its OK to overwrite the attributes of instances
- that were already in the Session.
-
- version_check
- Indicates if mappers that have version_id columns should verify
- that instances existing already within the Session should have
- this attribute compared to the freshly loaded value.
-
- querycontext
- the QueryContext, if any, used to generate the executed statement.
- If present, the attribute dictionary from this Context will be used
- as the basis for this SelectionContext's attribute dictionary. This
- allows query-compile-time operations to send messages to the
- result-processing-time operations.
- """
-
- def __init__(self, mapper, session, extension, **kwargs):
- self.populate_existing = kwargs.pop('populate_existing', False)
- self.version_check = kwargs.pop('version_check', False)
- querycontext = kwargs.pop('querycontext', None)
- if querycontext:
- kwargs['attributes'] = querycontext.attributes
- self.session = session
- self.extension = extension
+ self.populate_existing = query._populate_existing
+ self.version_check = query._version_check
self.identity_map = {}
self.stack = LoaderStack()
- super(SelectionContext, self).__init__(mapper, kwargs.pop('with_options', []), **kwargs)
-
- def accept_option(self, opt):
- """Accept a MapperOption which will process (modify) the state
- of this SelectionContext.
- """
- opt.process_selection_context(self)
+ self.options = query._with_options
+ self.attributes = query._attributes.copy()
+
+
class UndeferGroupOption(MapperOption):
def __init__(self, group):
self.group = group
- def process_query_context(self, context):
- context.attributes[('undefer', self.group)] = True
-
- def process_selection_context(self, context):
- context.attributes[('undefer', self.group)] = True
+ def process_query(self, query):
+ query._attributes[('undefer', self.group)] = True
class AbstractRelationLoader(LoaderStrategy):
def init(self):
def is_chained(self):
return not self.lazy and self.chained
- def process_query_property(self, context, properties):
+ def process_query_property(self, query, properties):
if self.lazy:
- if properties[-1] in context.eager_loaders:
- context.eager_loaders.remove(properties[-1])
+ if properties[-1] in query._eager_loaders:
+ query._eager_loaders = query._eager_loaders.difference(util.Set([properties[-1]]))
else:
- for prop in properties:
- context.eager_loaders.add(prop)
- super(EagerLazyOption, self).process_query_property(context, properties)
+ query._eager_loaders = query._eager_loaders.union(util.Set(properties))
+ super(EagerLazyOption, self).process_query_property(query, properties)
def get_strategy_class(self):
if self.lazy:
raise exceptions.ArgumentError("Fetchmode must be one of 'join' or 'select'")
self.type = type
- def process_selection_property(self, context, properties):
- context.attributes[('fetchmode', properties[-1])] = self.type
+ def process_query_property(self, query, properties):
+ query._attributes[('fetchmode', properties[-1])] = self.type
class RowDecorateOption(PropertyOption):
def __init__(self, key, decorator=None, alias=None):
self.decorator = decorator
self.alias = alias
- def process_selection_property(self, context, properties):
+ def process_query_property(self, query, properties):
if self.alias is not None and self.decorator is None:
if isinstance(self.alias, basestring):
self.alias = properties[-1].target.alias(self.alias)
d[c] = row[self.alias.corresponding_column(c)]
return d
self.decorator = decorate
- context.attributes[("eager_row_processor", properties[-1])] = self.decorator
+ query._attributes[("eager_row_processor", properties[-1])] = self.decorator
RowDecorateOption.logger = logging.class_logger(RowDecorateOption)
from sqlalchemy import sql, util, exceptions
from sqlalchemy.sql import util as sql_util
from sqlalchemy.sql import visitors
-from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE
+from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, build_path
all_cascades = util.Set(["delete", "delete-orphan", "all", "merge",
"expunge", "save-update", "refresh-expire", "none"])
def setdefault(self, col, value):
return super(TranslatingDict, self).setdefault(self.__translate_col(col), value)
-class ExtensionCarrier(MapperExtension):
+class ExtensionCarrier(object):
+ """stores a collection of MapperExtension objects.
+
+ allows extension methods to be called on the contained MapperExtensions
+ in the order they were added to this object. Also includes a 'methods' dictionary
+ accessor which allows for a quick check if a particular method
+ is overridden on any contained MapperExtensions.
+ """
+
def __init__(self, _elements=None):
- self.__elements = _elements or []
-
+ self.methods = {}
+ self.__elements = [self.__inspect(e) for e in (_elements or [])]
+
def copy(self):
return ExtensionCarrier(list(self.__elements))
def insert(self, extension):
"""Insert a MapperExtension at the beginning of this ExtensionCarrier's list."""
- self.__elements.insert(0, extension)
+ self.__elements.insert(0, self.__inspect(extension))
def append(self, extension):
"""Append a MapperExtension at the end of this ExtensionCarrier's list."""
- self.__elements.append(extension)
+ self.__elements.append(self.__inspect(extension))
- def _create_do(funcname):
- def _do(self, *args, **kwargs):
+ def __inspect(self, extension):
+ for meth in MapperExtension.__dict__.keys():
+ if meth not in self.methods and hasattr(extension, meth) and getattr(extension, meth) is not getattr(MapperExtension, meth):
+ self.methods[meth] = self.__create_do(meth)
+ return extension
+
+ def __create_do(self, funcname):
+ def _do(*args, **kwargs):
for elem in self.__elements:
ret = getattr(elem, funcname)(*args, **kwargs)
if ret is not EXT_CONTINUE:
return ret
else:
return EXT_CONTINUE
- return _do
- instrument_class = _create_do('instrument_class')
- init_instance = _create_do('init_instance')
- init_failed = _create_do('init_failed')
- dispose_class = _create_do('dispose_class')
- get_session = _create_do('get_session')
- load = _create_do('load')
- get = _create_do('get')
- get_by = _create_do('get_by')
- select_by = _create_do('select_by')
- select = _create_do('select')
- translate_row = _create_do('translate_row')
- create_instance = _create_do('create_instance')
- append_result = _create_do('append_result')
- populate_instance = _create_do('populate_instance')
- before_insert = _create_do('before_insert')
- before_update = _create_do('before_update')
- after_update = _create_do('after_update')
- after_insert = _create_do('after_insert')
- before_delete = _create_do('before_delete')
- after_delete = _create_do('after_delete')
+ try:
+ _do.__name__ = funcname
+ except TypeError:
+ # can't set __name__ on a function in py 2.3
+ pass
+ return _do
+
+ def _pass(self, *args, **kwargs):
+ return EXT_CONTINUE
+
+ def __getattr__(self, key):
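+ # hooks not overridden by any contained extension fall through to _pass, i.e. EXT_CONTINUE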
+ return self.methods.get(key, self._pass)
class BinaryVisitor(visitors.ClauseVisitor):
def __init__(self, func):
self.parentclauses = parentclauses
if parentclauses is not None:
- self.path = parentclauses.path + (prop.parent, prop.key)
+ self.path = build_path(prop.parent, prop.key, parentclauses.path)
else:
- self.path = (prop.parent, prop.key)
+ self.path = build_path(prop.parent, prop.key)
self.prop = prop