From f893bb0b513e8e27403da55da0d68e3b15a13629 Mon Sep 17 00:00:00 2001 From: Jason Kirtland Date: Mon, 18 Aug 2008 18:09:27 +0000 Subject: [PATCH] more ORM @decorator fliparoo --- lib/sqlalchemy/orm/attributes.py | 31 ++-- lib/sqlalchemy/orm/collections.py | 40 ++--- lib/sqlalchemy/orm/interfaces.py | 8 +- lib/sqlalchemy/orm/mapper.py | 16 +- lib/sqlalchemy/orm/properties.py | 11 +- lib/sqlalchemy/orm/query.py | 255 +++++++++++++++--------------- lib/sqlalchemy/orm/session.py | 61 +++---- lib/sqlalchemy/orm/unitofwork.py | 196 ++++++++++++----------- lib/sqlalchemy/orm/util.py | 3 +- 9 files changed, 314 insertions(+), 307 deletions(-) diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 7a001f18d1..2c1da53118 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -80,7 +80,7 @@ class QueryableAttribute(interfaces.PropComparator): def __init__(self, impl, comparator=None, parententity=None): """Construct an InstrumentedAttribute. - + comparator a sql.Comparator to which class-level compare/math events will be sent """ @@ -158,11 +158,11 @@ def proxied_attribute_factory(descriptor): self._parententity = parententity self.impl = _ProxyImpl(key) + @property def comparator(self): if callable(self._comparator): self._comparator = self._comparator() return self._comparator - comparator = property(comparator) def __get__(self, instance, owner): """Delegate __get__ to the original descriptor.""" @@ -399,9 +399,10 @@ class ScalarAttributeImpl(AttributeImpl): for ext in self.extensions: ext.remove(state, value, initiator or self) + @property def type(self): self.property.columns[0].type - type = property(type) + class MutableScalarAttributeImpl(ScalarAttributeImpl): """represents a scalar value-holding InstrumentedAttribute, which can detect @@ -480,7 +481,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): `initiator` is the ``InstrumentedAttribute`` that initiated the ``set()`` operation and is used to control the depth of a circular setter operation. - + """ if initiator is self: return @@ -752,7 +753,7 @@ class InstanceState(object): key = None runid = None expired_attributes = EMPTY_SET - + def __init__(self, obj, manager): self.class_ = obj.__class__ self.manager = manager @@ -817,7 +818,7 @@ class InstanceState(object): impl = self.get_impl(key) x = impl.get(self, passive=passive) if x is PASSIVE_NORESULT: - + return None elif hasattr(impl, 'get_collection'): return impl.get_collection(self, x, passive=passive) @@ -884,6 +885,7 @@ class InstanceState(object): del self.expired_attributes return ATTR_WAS_SET + @property def unmodified(self): """a set of keys which have no uncommitted changes""" @@ -893,21 +895,18 @@ class InstanceState(object): (key in self.manager.mutable_attributes and not self.manager[key].impl.check_mutable_modified(self)))) - unmodified = property(unmodified) - + @property def unloaded(self): """a set of keys which do not have a loaded value. - + This includes expired attributes and any other attribute that was never populated or modified. - + """ return set( key for key in self.manager.keys() if key not in self.committed_state and key not in self.dict) - unloaded = property(unloaded) - def expire_attributes(self, attribute_names): self.expired_attributes = set(self.expired_attributes) @@ -1158,13 +1157,13 @@ class ClassManager(dict): get_inst = dict.__getitem__ + @property def attributes(self): return self.itervalues() - attributes = property(attributes) + @classmethod def deferred_scalar_loader(cls, state, keys): """TODO""" - deferred_scalar_loader = classmethod(deferred_scalar_loader) ## InstanceState management @@ -1317,6 +1316,7 @@ class History(tuple): def __new__(cls, added, unchanged, deleted): return tuple.__new__(cls, (added, unchanged, deleted)) + @classmethod def from_attribute(cls, attribute, state, current): original = state.committed_state.get(attribute.key, NEVER_SET) @@ -1351,7 +1351,6 @@ class History(tuple): else: deleted = [] return cls([current], [], deleted) - from_attribute = classmethod(from_attribute) class PendingCollection(object): @@ -1434,7 +1433,7 @@ def register_attribute(class_, key, uselist, useobject, callable_=None, proxy_pr impl_class=impl_class, **kwargs), comparator=comparator, parententity=parententity) - + manager.instrument_attribute(key, descriptor) def unregister_attribute(class_, key): diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 0efd1d8674..f8570dd5fc 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -208,7 +208,8 @@ class collection(object): # Bundled as a class solely for ease of use: packaging, doc strings, # importability. - def appender(cls, fn): + @staticmethod + def appender(fn): """Tag the method as the collection appender. The appender method is called with one positional argument: the value @@ -250,9 +251,9 @@ class collection(object): """ setattr(fn, '_sa_instrument_role', 'appender') return fn - appender = classmethod(appender) - def remover(cls, fn): + @staticmethod + def remover(fn): """Tag the method as the collection remover. The remover method is called with one positional argument: the value @@ -277,9 +278,9 @@ class collection(object): """ setattr(fn, '_sa_instrument_role', 'remover') return fn - remover = classmethod(remover) - def iterator(cls, fn): + @staticmethod + def iterator(fn): """Tag the method as the collection remover. The iterator method is called with no arguments. It is expected to @@ -291,9 +292,9 @@ class collection(object): """ setattr(fn, '_sa_instrument_role', 'iterator') return fn - iterator = classmethod(iterator) - def internally_instrumented(cls, fn): + @staticmethod + def internally_instrumented(fn): """Tag the method as instrumented. This tag will prevent any decoration from being applied to the method. @@ -311,9 +312,9 @@ class collection(object): """ setattr(fn, '_sa_instrumented', True) return fn - internally_instrumented = classmethod(internally_instrumented) - def on_link(cls, fn): + @staticmethod + def on_link(fn): """Tag the method as a the "linked to attribute" event handler. This optional event handler will be called when the collection class @@ -325,9 +326,9 @@ class collection(object): """ setattr(fn, '_sa_instrument_role', 'on_link') return fn - on_link = classmethod(on_link) - def converter(cls, fn): + @staticmethod + def converter(fn): """Tag the method as the collection converter. This optional method will be called when a collection is being @@ -358,9 +359,9 @@ class collection(object): """ setattr(fn, '_sa_instrument_role', 'converter') return fn - converter = classmethod(converter) - def adds(cls, arg): + @staticmethod + def adds(arg): """Mark the method as adding an entity to the collection. Adds "add to collection" handling to the method. The decorator @@ -379,9 +380,9 @@ class collection(object): setattr(fn, '_sa_instrument_before', ('fire_append_event', arg)) return fn return decorator - adds = classmethod(adds) - def replaces(cls, arg): + @staticmethod + def replaces(arg): """Mark the method as replacing an entity in the collection. Adds "add to collection" and "remove from collection" handling to @@ -400,9 +401,9 @@ class collection(object): setattr(fn, '_sa_instrument_after', 'fire_remove_event') return fn return decorator - replaces = classmethod(replaces) - def removes(cls, arg): + @staticmethod + def removes(arg): """Mark the method as removing an entity in the collection. Adds "remove from collection" handling to the method. The decorator @@ -421,9 +422,9 @@ class collection(object): setattr(fn, '_sa_instrument_before', ('fire_remove_event', arg)) return fn return decorator - removes = classmethod(removes) - def removes_return(cls): + @staticmethod + def removes_return(): """Mark the method as removing an entity in the collection. Adds "remove from collection" handling to the method. The return value @@ -441,7 +442,6 @@ class collection(object): setattr(fn, '_sa_instrument_after', 'fire_remove_event') return fn return decorator - removes_return = classmethod(removes_return) # public instrumentation interface for 'internally instrumented' diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 0b60483a33..29bc980bc2 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -441,25 +441,25 @@ class PropComparator(expression.ColumnOperators): def __clause_element__(self): raise NotImplementedError("%r" % self) + @staticmethod def contains_op(a, b): return a.contains(b) - contains_op = staticmethod(contains_op) + @staticmethod def any_op(a, b, **kwargs): return a.any(b, **kwargs) - any_op = staticmethod(any_op) + @staticmethod def has_op(a, b, **kwargs): return a.has(b, **kwargs) - has_op = staticmethod(has_op) def __init__(self, prop, mapper): self.prop = self.property = prop self.mapper = mapper + @staticmethod def of_type_op(a, class_): return a.of_type(class_) - of_type_op = staticmethod(of_type_op) def of_type(self, class_): """Redefine this object in terms of a polymorphic subclass. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index cea61f28e9..3e5f418fa1 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -237,11 +237,11 @@ class Mapper(object): raise sa_exc.InvalidRequestError("Mapper '%s' has no property '%s'" % (str(self), key)) return prop + @property def iterate_properties(self): """return an iterator of all MapperProperty objects.""" self.compile() return self.__props.itervalues() - iterate_properties = property(iterate_properties) def __mappers_from_spec(self, spec, selectable): """given a with_polymorphic() argument, return the set of mappers it represents. @@ -282,12 +282,15 @@ class Mapper(object): return from_obj + @property + @util.cache_decorator def _with_polymorphic_mappers(self): if not self.with_polymorphic: return [self] return self.__mappers_from_spec(*self.with_polymorphic) - _with_polymorphic_mappers = property(util.cache_decorator(_with_polymorphic_mappers)) + @property + @util.cache_decorator def _with_polymorphic_selectable(self): if not self.with_polymorphic: return self.mapped_table @@ -297,7 +300,6 @@ class Mapper(object): return selectable else: return self.__selectable_from_mappers(self.__mappers_from_spec(spec, selectable)) - _with_polymorphic_selectable = property(util.cache_decorator(_with_polymorphic_selectable)) def _with_polymorphic_args(self, spec=None, selectable=False): if self.with_polymorphic: @@ -319,9 +321,9 @@ class Mapper(object): chain(*[list(mapper.iterate_properties) for mapper in [self] + mappers]) )) + @property def properties(self): raise NotImplementedError("Public collection of MapperProperty objects is provided by the get_property() and iterate_properties accessors.") - properties = property(properties) def dispose(self): # Disable any attribute-based compilation. @@ -557,6 +559,8 @@ class Mapper(object): self.primary_key = primary_key self.__log("Identified primary key columns: " + str(primary_key)) + @property + @util.cache_decorator def _get_clause(self): """create a "get clause" based on the primary key. this is used by query.get() and many-to-one lazyloads to load this item @@ -565,8 +569,9 @@ class Mapper(object): """ params = [(primary_key, sql.bindparam(None, type_=primary_key.type)) for primary_key in self.primary_key] return sql.and_(*[k==v for (k, v) in params]), dict(params) - _get_clause = property(util.cache_decorator(_get_clause)) + @property + @util.cache_decorator def _equivalent_columns(self): """Create a map of all *equivalent* columns, based on the determination of column pairs that are equated to @@ -604,7 +609,6 @@ class Mapper(object): visitors.traverse(mapper.inherit_condition, {}, {'binary':visit_binary}) return result - _equivalent_columns = property(util.cache_decorator(_equivalent_columns)) class _CompileOnAttr(PropComparator): """A placeholder descriptor which triggers compilation on access.""" diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index ebe61b19e2..3d717cae08 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -310,9 +310,9 @@ class PropertyLoader(StrategizedProperty): if of_type: self._of_type = _class_to_mapper(of_type) + @property def parententity(self): return self.prop.parent - parententity = property(parententity) def __clause_element__(self): return self.prop.parent._with_polymorphic_selectable @@ -325,10 +325,7 @@ class PropertyLoader(StrategizedProperty): def of_type(self, cls): return PropertyLoader.Comparator(self.prop, self.mapper, cls) - - def in_(self, other): - raise NotImplementedError("in_() not yet supported for relations. For a simple many-to-one, use in_() against the set of foreign key values.") - + def __eq__(self, other): if other is None: if self.prop.direction in [ONETOMANY, MANYTOMANY]: @@ -530,9 +527,7 @@ class PropertyLoader(StrategizedProperty): self.mapper = mapper.class_mapper(self.argument(), compile=False) else: raise sa_exc.ArgumentError("relation '%s' expects a class or a mapper argument (received: %s)" % (self.key, type(self.argument))) - - # TODO: an informative assertion ? - assert isinstance(self.mapper, mapper.Mapper) + assert isinstance(self.mapper, mapper.Mapper), self.mapper # accept callables for other attributes which may require deferred initialization for attr in ('order_by', 'primaryjoin', 'secondaryjoin', 'secondary', '_foreign_keys', 'remote_side'): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index ed3c2f7c38..784d30c8d6 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -101,7 +101,7 @@ class Query(object): self._mapper_adapter_map = d = self._mapper_adapter_map.copy() else: self._mapper_adapter_map = d = {} - + for ent in entities: for entity in ent.entities: if entity not in d: @@ -118,7 +118,7 @@ class Query(object): d[entity] = (mapper, adapter, selectable, is_aliased_class, with_polymorphic) ent.setup_entity(entity, *d[entity]) - + def __mapper_loads_polymorphically_with(self, mapper, adapter): for m2 in mapper._with_polymorphic_mappers: for m in m2.iterate_to_root(): @@ -161,7 +161,7 @@ class Query(object): alias = self._polymorphic_adapters.get(search, None) if alias: return alias.adapt_clause(element) - + def __replace_element(self, adapters): def replace(elem): if '_halt_adapt' in elem._annotations: @@ -172,7 +172,7 @@ class Query(object): if e: return e return replace - + def __replace_orm_element(self, adapters): def replace(elem): if '_halt_adapt' in elem._annotations: @@ -190,7 +190,7 @@ class Query(object): self._disable_orm_filtering = True def _adapt_clause(self, clause, as_filter, orm_only): - adapters = [] + adapters = [] if as_filter and self._filter_aliases: adapters.append(self._filter_aliases.replace) @@ -202,12 +202,12 @@ class Query(object): if not adapters: return clause - + if getattr(self, '_disable_orm_filtering', not orm_only): return visitors.replacement_traverse(clause, {'column_collections':False}, self.__replace_element(adapters)) else: return visitors.replacement_traverse(clause, {'column_collections':False}, self.__replace_orm_element(adapters)) - + def _entity_zero(self): return self._entities[0] @@ -218,6 +218,7 @@ class Query(object): ent = self._entity_zero() return getattr(ent, 'extension', ent.mapper.extension) + @property def _mapper_entities(self): # TODO: this is wrong, its hardcoded to "priamry entity" when # for the case of __all_equivs() it should not be @@ -225,7 +226,6 @@ class Query(object): for ent in self._entities: if hasattr(ent, 'primary_entity'): yield ent - _mapper_entities = property(_mapper_entities) def _joinpoint_zero(self): return self._joinpoint or self._entity_zero().entity_zero @@ -310,35 +310,36 @@ class Query(object): q.__dict__ = self.__dict__.copy() return q + @property def statement(self): - """return the full SELECT statement represented by this Query.""" + """The full SELECT statement represented by this Query.""" return self._compile_context(labels=self._with_labels).statement._annotate({'_halt_adapt': True}) - statement = property(statement) def subquery(self): """return the full SELECT statement represented by this Query, embedded within an Alias.""" - + return self.statement.alias() - + @_generative() def with_labels(self): """Apply column labels to the return value of Query.statement. - Indicates that this Query's `statement` accessor should return a - SELECT statement that applies labels to all columns in the form - _; this is commonly used to disambiguate - columns from multiple tables which have the same name. + Indicates that this Query's `statement` accessor should return + a SELECT statement that applies labels to all columns in the + form _; this is commonly used to + disambiguate columns from multiple tables which have the same + name. - When the `Query` actually issues SQL to load rows, it always uses - column labeling. + When the `Query` actually issues SQL to load rows, it always + uses column labeling. """ self._with_labels = True + @property def whereclause(self): - """return the WHERE criterion for this Query.""" + """The WHERE criterion for this Query.""" return self._criterion - whereclause = property(whereclause) @_generative() def _with_current_path(self, path): @@ -703,7 +704,7 @@ class Query(object): session.query(Person).join((Palias, Person.friends)) # join from Houses to the "rooms" attribute on the - # "Colonials" subclass of Houses, then join to the + # "Colonials" subclass of Houses, then join to the # "closets" relation on Room session.query(Houses).join(Colonials.rooms, Room.closets) @@ -731,7 +732,7 @@ class Query(object): approach to this. from_joinpoint - when joins are specified using string property names, - locate the property from the mapper found in the most recent previous + locate the property from the mapper found in the most recent previous join() call, instead of from the root entity. """ @@ -744,7 +745,7 @@ class Query(object): def outerjoin(self, *props, **kwargs): """Create a left outer join against this ``Query`` object's criterion and apply generatively, retunring the newly resulting ``Query``. - + Usage is the same as the ``join()`` method. """ @@ -770,18 +771,18 @@ class Query(object): alias_criterion = False left_entity = right_entity right_entity = right_mapper = None - + if isinstance(arg1, tuple): arg1, arg2 = arg1 else: arg2 = None - + if isinstance(arg2, (interfaces.PropComparator, basestring)): onclause = arg2 right_entity = arg1 elif isinstance(arg1, (interfaces.PropComparator, basestring)): onclause = arg1 - right_entity = arg2 + right_entity = arg2 else: onclause = arg2 right_entity = arg1 @@ -790,22 +791,22 @@ class Query(object): of_type = getattr(onclause, '_of_type', None) prop = onclause.property descriptor = onclause - + if not left_entity: left_entity = onclause.parententity - + if of_type: right_mapper = of_type else: right_mapper = prop.mapper - + if not right_entity: right_entity = right_mapper - + elif isinstance(onclause, basestring): if not left_entity: left_entity = self._joinpoint_zero() - + descriptor, prop = _entity_descriptor(left_entity, onclause) right_mapper = prop.mapper if not right_entity: @@ -816,7 +817,7 @@ class Query(object): else: if not left_entity: left_entity = self._joinpoint_zero() - + if not clause: if isinstance(onclause, interfaces.PropComparator): clause = onclause.__clause_element__() @@ -830,7 +831,7 @@ class Query(object): raise sa_exc.InvalidRequestError("Could not find a FROM clause to join from") mp, right_selectable, is_aliased_class = _entity_info(right_entity) - + if not right_mapper and mp: right_mapper = mp @@ -854,7 +855,7 @@ class Query(object): right_entity = aliased(right_mapper) alias_criterion = True aliased_entity = True - + elif prop: if prop.table in self.__currenttables: if prop.secondary is not None and prop.secondary not in self.__currenttables: @@ -871,19 +872,19 @@ class Query(object): right_entity = prop.mapper if alias_criterion: - right_adapter = ORMAdapter(right_entity, + right_adapter = ORMAdapter(right_entity, equivalents=right_mapper._equivalent_columns, chain_to=self._filter_aliases) - + if isinstance(onclause, sql.ClauseElement): onclause = right_adapter.traverse(onclause) if prop: onclause = prop - + clause = orm_join(clause, right_entity, onclause, isouter=outerjoin) - if alias_criterion: + if alias_criterion: self._filter_aliases = right_adapter - + if aliased_entity: self.__mapper_loads_polymorphically_with(right_mapper, ORMAdapter(right_entity, equivalents=right_mapper._equivalent_columns)) @@ -1045,7 +1046,7 @@ class Query(object): def _execute_and_instances(self, querycontext): result = self.session.execute(querycontext.statement, params=self._params, mapper=self._mapper_zero_or_none(), _state=self._refresh_state) return self.instances(result, querycontext) - + def instances(self, cursor, __context=None): """Given a ResultProxy cursor as returned by connection.execute(), return an ORM result as an iterator. @@ -1084,7 +1085,7 @@ class Query(object): if label) rowtuple = type.__new__(type, "RowTuple", (tuple,), labels) rowtuple.keys = labels.keys - + while True: context.progress = set() context.partials = {} @@ -1095,7 +1096,7 @@ class Query(object): break else: fetch = cursor.fetchall() - + if custom_rows: rows = [] for row in fetch: @@ -1159,7 +1160,7 @@ class Query(object): _get_clause = q._adapt_clause(_get_clause, True, False) q._criterion = _get_clause - + for i, primary_key in enumerate(mapper.primary_key): try: params[_get_params[primary_key].key] = ident[i] @@ -1170,9 +1171,9 @@ class Query(object): if lockmode is not None: q._lockmode = lockmode q.__get_options( - populate_existing=bool(refresh_state), - version_check=(lockmode is not None), - only_load_props=only_load_props, + populate_existing=bool(refresh_state), + version_check=(lockmode is not None), + only_load_props=only_load_props, refresh_state=refresh_state) q._order_by = None try: @@ -1181,33 +1182,35 @@ class Query(object): except IndexError: return None + @property def _select_args(self): return { - 'limit':self._limit, - 'offset':self._offset, - 'distinct':self._distinct, - 'group_by':self._group_by or None, + 'limit':self._limit, + 'offset':self._offset, + 'distinct':self._distinct, + 'group_by':self._group_by or None, 'having':self._having or None } - _select_args = property(_select_args) + @property def _should_nest_selectable(self): kwargs = self._select_args - return (kwargs.get('limit') is not None or kwargs.get('offset') is not None or kwargs.get('distinct', False)) - _should_nest_selectable = property(_should_nest_selectable) + return (kwargs.get('limit') is not None or + kwargs.get('offset') is not None or + kwargs.get('distinct', False)) def count(self): """Apply this query's criterion to a SELECT COUNT statement.""" - + return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self._only_mapper_zero().primary_key)) def _col_aggregate(self, col, func, nested_cols=None): context = QueryContext(self) self._adjust_for_single_inheritance(context) - + whereclause = context.whereclause - + from_obj = self.__mapper_zero_from_obj() if self._should_nest_selectable: @@ -1222,7 +1225,7 @@ class Query(object): if self._autoflush and not self._populate_existing: self.session._autoflush() return self.session.scalar(s, params=self._params, mapper=self._mapper_zero()) - + def delete(self, synchronize_session='evaluate'): """EXPERIMENTAL""" #TODO: lots of duplication and ifs - probably needs to be refactored to strategies @@ -1230,30 +1233,30 @@ class Query(object): if len(context.statement.froms) != 1 or not isinstance(context.statement.froms[0], schema.Table): raise sa_exc.ArgumentError("Only deletion via a single table query is currently supported") primary_table = context.statement.froms[0] - + session = self.session - + if synchronize_session == 'evaluate': try: evaluator_compiler = evaluator.EvaluatorCompiler() eval_condition = evaluator_compiler.process(self.whereclause) except evaluator.UnevaluatableError: synchronize_session = 'fetch' - + delete_stmt = sql.delete(primary_table, context.whereclause) - + if synchronize_session == 'fetch': #TODO: use RETURNING when available select_stmt = context.statement.with_only_columns(primary_table.primary_key) matched_rows = session.execute(select_stmt).fetchall() - + if self._autoflush: session._autoflush() session.execute(delete_stmt) - + if synchronize_session == 'evaluate': target_cls = self._mapper_zero().class_ - + #TODO: detect when the where clause is a trivial primary key match objs_to_expunge = [obj for (cls, pk),obj in session.identity_map.iteritems() if issubclass(cls, target_cls) and eval_condition(obj)] @@ -1268,66 +1271,66 @@ class Query(object): def update(self, values, synchronize_session='evaluate'): """EXPERIMENTAL""" - + #TODO: value keys need to be mapped to corresponding sql cols and instr.attr.s to string keys #TODO: updates of manytoone relations need to be converted to fk assignments - + context = self._compile_context() if len(context.statement.froms) != 1 or not isinstance(context.statement.froms[0], schema.Table): raise sa_exc.ArgumentError("Only update via a single table query is currently supported") primary_table = context.statement.froms[0] - + session = self.session - + if synchronize_session == 'evaluate': try: evaluator_compiler = evaluator.EvaluatorCompiler() eval_condition = evaluator_compiler.process(self.whereclause) - + value_evaluators = {} for key,value in values.items(): value_evaluators[key] = evaluator_compiler.process(expression._literal_as_binds(value)) except evaluator.UnevaluatableError: synchronize_session = 'expire' - + update_stmt = sql.update(primary_table, context.whereclause, values) - + if synchronize_session == 'expire': select_stmt = context.statement.with_only_columns(primary_table.primary_key) matched_rows = session.execute(select_stmt).fetchall() - + if self._autoflush: session._autoflush() session.execute(update_stmt) - + if synchronize_session == 'evaluate': target_cls = self._mapper_zero().class_ - + for (cls, pk),obj in session.identity_map.iteritems(): evaluated_keys = value_evaluators.keys() - + if issubclass(cls, target_cls) and eval_condition(obj): state = attributes.instance_state(obj) - + # only evaluate unmodified attributes to_evaluate = state.unmodified.intersection(evaluated_keys) for key in to_evaluate: state.dict[key] = value_evaluators[key](obj) - + state.commit(list(to_evaluate)) - + # expire attributes with pending changes (there was no autoflush, so they are overwritten) state.expire_attributes(set(evaluated_keys).difference(to_evaluate)) - + elif synchronize_session == 'expire': target_mapper = self._mapper_zero() - + for primary_key in matched_rows: identity_key = target_mapper.identity_key_from_primary_key(list(primary_key)) if identity_key in session.identity_map: session.expire(session.identity_map[identity_key], values.keys()) - - + + def _compile_context(self, labels=True): context = QueryContext(self) @@ -1354,16 +1357,16 @@ class Query(object): froms = [context.from_clause] # "load from a single FROM" mode, i.e. when select_from() or join() is used else: froms = context.froms # "load from discrete FROMs" mode, i.e. when each _MappedEntity has its own FROM - + self._adjust_for_single_inheritance(context) - + if not context.primary_columns: if self._only_load_props: raise sa_exc.InvalidRequestError("No column-based properties specified for refresh operation." " Use session.expire() to reload collections and related items.") else: raise sa_exc.InvalidRequestError("Query contains no columns with which to SELECT from.") - + if eager_joins and self._should_nest_selectable: # for eager joins present and LIMIT/OFFSET/DISTINCT, wrap the query inside a select, # then append eager joins onto that @@ -1375,12 +1378,12 @@ class Query(object): order_by_col_expr = [] inner = sql.select( - context.primary_columns + order_by_col_expr, - context.whereclause, - from_obj=froms, - use_labels=labels, - correlate=False, - order_by=context.order_by, + context.primary_columns + order_by_col_expr, + context.whereclause, + from_obj=froms, + use_labels=labels, + correlate=False, + order_by=context.order_by, **self._select_args ) @@ -1394,10 +1397,10 @@ class Query(object): context.adapter = sql_util.ColumnAdapter(inner, equivs) statement = sql.select([inner] + context.secondary_columns, for_update=for_update, use_labels=labels) - + from_clause = inner for eager_join in eager_joins: - # EagerLoader places a 'stop_on' attribute on the join, + # EagerLoader places a 'stop_on' attribute on the join, # giving us a marker as to where the "splice point" of the join should be from_clause = sql_util.splice_joins(from_clause, eager_join, eager_join.stop_on) @@ -1418,24 +1421,24 @@ class Query(object): froms += context.eager_joins.values() statement = sql.select( - context.primary_columns + context.secondary_columns, - context.whereclause, - from_obj=froms, - use_labels=labels, - for_update=for_update, - correlate=False, - order_by=context.order_by, + context.primary_columns + context.secondary_columns, + context.whereclause, + from_obj=froms, + use_labels=labels, + for_update=for_update, + correlate=False, + order_by=context.order_by, **self._select_args ) - + if self._correlate: statement = statement.correlate(*self._correlate) if context.eager_order_by: statement.append_order_by(*context.eager_order_by) - + context.statement = statement - + return context def _adjust_for_single_inheritance(self, context): @@ -1490,7 +1493,7 @@ class _MapperEntity(_QueryEntity): self.entities = [entity] self.entity_zero = entity - + def setup_entity(self, entity, mapper, adapter, from_obj, is_aliased_class, with_polymorphic): self.mapper = mapper self.extension = self.mapper.extension @@ -1508,7 +1511,7 @@ class _MapperEntity(_QueryEntity): if cls_or_mappers is None: query._reset_polymorphic_adapter(self.mapper) return - + mappers, from_obj = self.mapper._with_polymorphic_args(cls_or_mappers, selectable) self._with_polymorphic = mappers @@ -1560,46 +1563,46 @@ class _MapperEntity(_QueryEntity): adapter = sql_util.ColumnAdapter(self.selectable, self.mapper._equivalent_columns) if self.primary_entity: - _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter, + _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter, extension=self.extension, only_load_props=query._only_load_props, refresh_state=context.refresh_state ) else: _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter) - + if custom_rows: def main(context, row, result): _instance(row, result) else: def main(context, row): return _instance(row, None) - + if self.is_aliased_class: entname = self.entity._sa_label_name else: entname = self.mapper.class_.__name__ - + return main, entname def setup_context(self, query, context): adapter = self._get_entity_clauses(query, context) - + context.froms.append(self.selectable) if context.order_by is False and self.mapper.order_by: context.order_by = self.mapper.order_by - + if context.order_by and adapter: context.order_by = adapter.adapt_list(util.to_list(context.order_by)) - + for value in self.mapper._iterate_polymorphic_properties(self._with_polymorphic): if query._only_load_props and value.key not in query._only_load_props: continue value.setup( - context, - self, - (self.path_entity,), - adapter, - only_load_props=query._only_load_props, + context, + self, + (self.path_entity,), + adapter, + only_load_props=query._only_load_props, column_collection=context.primary_columns ) @@ -1615,7 +1618,7 @@ class _ColumnEntity(_QueryEntity): for c in column.c: _ColumnEntity(query, c) return - + query._entities.append(self) if isinstance(column, basestring): @@ -1628,15 +1631,15 @@ class _ColumnEntity(_QueryEntity): raise sa_exc.InvalidRequestError("Invalid column expression '%r'" % column) else: self._result_label = getattr(column, 'key', None) - + if not hasattr(column, '_label'): column = column.label(None) self.column = column self.froms = set() - + # look for ORM entities represented within the - # given expression. Try to count only entities + # given expression. Try to count only entities # for columns whos FROM object is in the actual list # of FROMs for the overall expression - this helps # subqueries which were built from ORM constructs from @@ -1649,12 +1652,12 @@ class _ColumnEntity(_QueryEntity): if 'parententity' in elem._annotations and actual_froms.intersection(elem._get_from_objects()) ) - + if self.entities: self.entity_zero = list(self.entities)[0] else: self.entity_zero = None - + def setup_entity(self, entity, mapper, adapter, from_obj, is_aliased_class, with_polymorphic): self.selectable = from_obj self.froms.add(from_obj) @@ -1679,7 +1682,7 @@ class _ColumnEntity(_QueryEntity): def proc(context, row): return row[column] - + return (proc, self._result_label) def setup_context(self, query, context): @@ -1707,7 +1710,7 @@ class QueryContext(object): self.order_by = query._order_by if self.order_by: self.order_by = [expression._literal_as_text(o) for o in util.to_list(self.order_by)] - + self.query = query self.session = query.session self.populate_existing = query._populate_existing @@ -1735,7 +1738,7 @@ class AliasOption(interfaces.MapperOption): else: alias = self.alias query._from_obj_alias = sql_util.ColumnAdapter(alias) - + _runid = 1L _id_lock = util.threading.Lock() diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index ca9be7ffb1..b1a1ebe16c 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -87,14 +87,14 @@ def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False, Defaults to ``True``. When ``True``, all instances will be fully expired after each ``commit()``, so that all attribute/object access subsequent to a completed transaction will load from the most recent database state. - + _enable_transaction_accounting Defaults to ``True``. A legacy-only flag which when ``False`` disables *all* 0.5-style object accounting on transaction boundaries, including auto-expiry of instances on rollback and commit, maintenance of the "new" and "deleted" lists upon rollback, and autoflush of pending changes upon begin(), all of which are interdependent. - + autoflush When ``True``, all query operations will issue a ``flush()`` call to this ``Session`` before proceeding. This is a convenience feature so @@ -156,7 +156,7 @@ def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False, query_cls Class which should be used to create new Query objects, as returned by the ``query()`` method. Defaults to [sqlalchemy.orm.query#Query]. - + weak_identity_map When set to the default value of ``False``, a weak-referencing map is used; instances which are not externally referenced will be garbage @@ -233,9 +233,9 @@ class SessionTransaction(object): if self.session._enable_transaction_accounting: self._take_snapshot() + @property def is_active(self): return self.session is not None and self._active - is_active = property(is_active) def _assert_is_active(self): self._assert_is_open() @@ -248,9 +248,9 @@ class SessionTransaction(object): if self.session is None: raise sa_exc.InvalidRequestError("The transaction is closed") + @property def _is_transaction_boundary(self): return self.nested or not self._parent - _is_transaction_boundary = property(_is_transaction_boundary) def connection(self, bindkey, **kwargs): self._assert_is_active() @@ -1137,11 +1137,11 @@ class Session(object): state = attributes.instance_state(instance) except exc.NO_STATE: raise exc.UnmappedInstanceError(instance) - + # grab the full cascade list first, since lazyloads/autoflush # may be triggered by this operation (delete cascade lazyloads by default) cascade_states = list(_cascade_state_iterator('delete', state)) - self._delete_impl(state) + self._delete_impl(state) for state, m, o in cascade_states: self._delete_impl(state, ignore_transient=True) @@ -1463,61 +1463,62 @@ class Session(object): if added or deleted: return True return False - + + @property def is_active(self): - """return True if this Session has an active transaction.""" - + """True if this Session has an active transaction.""" + return self.transaction and self.transaction.is_active - is_active = property(is_active) - + + @property def _dirty_states(self): - """Return a set of all persistent states considered dirty. + """The set of all persistent states considered dirty. - This method returns all states that were modified including those that - were possibly deleted. + This method returns all states that were modified including + those that were possibly deleted. """ return util.IdentitySet( [state for state in self.identity_map.all_states() if state.check_modified()]) - _dirty_states = property(_dirty_states) + @property def dirty(self): - """Return a set of all persistent instances considered dirty. + """The set of all persistent instances considered dirty. Instances are considered dirty when they were modified but not deleted. Note that this 'dirty' calculation is 'optimistic'; most - attribute-setting or collection modification operations will mark an - instance as 'dirty' and place it in this set, even if there is no net - change to the attribute's value. At flush time, the value of each - attribute is compared to its previously saved value, and if there's no - net change, no SQL operation will occur (this is a more expensive - operation so it's only done at flush time). + attribute-setting or collection modification operations will + mark an instance as 'dirty' and place it in this set, even if + there is no net change to the attribute's value. At flush + time, the value of each attribute is compared to its + previously saved value, and if there's no net change, no SQL + operation will occur (this is a more expensive operation so + it's only done at flush time). - To check if an instance has actionable net changes to its attributes, - use the is_modified() method. + To check if an instance has actionable net changes to its + attributes, use the is_modified() method. """ return util.IdentitySet( [state.obj() for state in self._dirty_states if state not in self._deleted]) - dirty = property(dirty) + @property def deleted(self): - "Return a set of all instances marked as 'deleted' within this ``Session``" + "The set of all instances marked as 'deleted' within this ``Session``" return util.IdentitySet(self._deleted.values()) - deleted = property(deleted) + @property def new(self): - "Return a set of all instances marked as 'new' within this ``Session``." + "The set of all instances marked as 'new' within this ``Session``." return util.IdentitySet(self._new.values()) - new = property(new) def _expire_state(state, attribute_names): """Stand-alone expire instance function. diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 4116b7a015..2b5b8ae1f3 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -38,7 +38,7 @@ class UOWEventHandler(interfaces.AttributeExtension): def __init__(self, key): self.key = key - + def append(self, state, item, initiator): # process "save_update" cascade rules for when an instance is appended to the list of another instance sess = _state_session(state) @@ -74,7 +74,7 @@ def register_attribute(class_, key, *args, **kwargs): """overrides attributes.register_attribute() to add UOW event handlers to new InstrumentedAttributes. """ - + useobject = kwargs.get('useobject', False) if useobject: # for object-holding attributes, instrument UOWEventHandler @@ -83,14 +83,14 @@ def register_attribute(class_, key, *args, **kwargs): extension.insert(0, UOWEventHandler(key)) kwargs['extension'] = extension return attributes.register_attribute(class_, key, *args, **kwargs) - + class UOWTransaction(object): """Handles the details of organizing and executing transaction - tasks during a UnitOfWork object's flush() operation. - + tasks during a UnitOfWork object's flush() operation. + The central operation is to form a graph of nodes represented by the ``UOWTask`` class, which is then traversed by a ``UOWExecutor`` object that issues SQL and instance-synchronizing operations via the related @@ -100,16 +100,16 @@ class UOWTransaction(object): def __init__(self, session): self.session = session self.mapper_flush_opts = session._mapper_flush_opts - + # stores tuples of mapper/dependent mapper pairs, # representing a partial ordering fed into topological sort self.dependencies = set() - + # dictionary of mappers to UOWTasks self.tasks = {} - + # dictionary used by external actors to store arbitrary state - # information. + # information. self.attributes = {} self.logger = log.instance_logger(self, echoflag=session.echo_uow) @@ -118,7 +118,7 @@ class UOWTransaction(object): hashkey = ("history", state, key) # cache the objects, not the states; the strong reference here - # prevents newly loaded objects from being dereferenced during the + # prevents newly loaded objects from being dereferenced during the # flush process if hashkey in self.attributes: (added, unchanged, deleted, cached_passive) = self.attributes[hashkey] @@ -160,7 +160,7 @@ class UOWTransaction(object): def set_row_switch(self, state): """mark a deleted object as a 'row switch'. - + this indicates that an INSERT statement elsewhere corresponds to this DELETE; the INSERT is converted to an UPDATE and the DELETE does not occur. """ @@ -168,18 +168,18 @@ class UOWTransaction(object): task = self.get_task_by_mapper(mapper) taskelement = task._objects[state] taskelement.isdelete = "rowswitch" - + def is_deleted(self, state): """return true if the given state is marked as deleted within this UOWTransaction.""" - + mapper = _state_mapper(state) task = self.get_task_by_mapper(mapper) return task.is_deleted(state) - + def get_task_by_mapper(self, mapper, dontcreate=False): """return UOWTask element corresponding to the given mapper. - Will create a new UOWTask, including a UOWTask corresponding to the + Will create a new UOWTask, including a UOWTask corresponding to the "base" inherited mapper, if needed, unless the dontcreate flag is True. """ try: @@ -187,7 +187,7 @@ class UOWTransaction(object): except KeyError: if dontcreate: return None - + base_mapper = mapper.base_mapper if base_mapper in self.tasks: base_task = self.tasks[base_mapper] @@ -200,7 +200,7 @@ class UOWTransaction(object): mapper._register_dependencies(self) else: task = self.tasks[mapper] - + return task def register_dependency(self, mapper, dependency): @@ -208,7 +208,7 @@ class UOWTransaction(object): Called by ``mapper.PropertyLoader`` to register the objects handled by one mapper being dependent on the objects handled - by another. + by another. """ # correct for primary mapper @@ -223,7 +223,7 @@ class UOWTransaction(object): def register_processor(self, mapper, processor, mapperfrom): """register a dependency processor, corresponding to dependencies between the two given mappers. - + """ # correct for primary mapper @@ -234,14 +234,14 @@ class UOWTransaction(object): targettask = self.get_task_by_mapper(mapperfrom) up = UOWDependencyProcessor(processor, targettask) task.dependencies.add(up) - + def execute(self): """Execute this UOWTransaction. - + This will organize all collected UOWTasks into a dependency-sorted list which is then traversed using the traversal scheme encoded in the UOWExecutor class. Operations to mappers and dependency - processors are fired off in order to issue SQL to the database and + processors are fired off in order to issue SQL to the database and synchronize instance attributes with database values and related foreign key values.""" @@ -271,18 +271,18 @@ class UOWTransaction(object): import uowdumper uowdumper.UOWDumper(tasks, buf) return buf.getvalue() - + + @property def elements(self): - """return an iterator of all UOWTaskElements within this UOWTransaction.""" + """An iterator of all UOWTaskElements within this UOWTransaction.""" for task in self.tasks.values(): for elem in task.elements: yield elem - elements = property(elements) - + def finalize_flush_changes(self): """mark processed objects as clean / deleted after a successful flush(). - - this method is called within the flush() method after the + + this method is called within the flush() method after the execute() method has succeeded and the transaction has been committed. """ @@ -293,7 +293,7 @@ class UOWTransaction(object): self.session._register_newly_persistent(elem.state) def _sort_dependencies(self): - nodes = topological.sort_with_cycles(self.dependencies, + nodes = topological.sort_with_cycles(self.dependencies, [t.mapper for t in self.tasks.values() if t.base_task is t] ) @@ -315,8 +315,8 @@ class UOWTransaction(object): class UOWTask(object): """Represents all of the objects in the UOWTransaction which correspond to - a particular mapper. - + a particular mapper. + """ def __init__(self, uowtransaction, mapper, base_task=None): self.uowtransaction = uowtransaction @@ -333,42 +333,46 @@ class UOWTask(object): else: self.base_task = base_task base_task._inheriting_tasks[mapper] = self - + # the Mapper which this UOWTask corresponds to self.mapper = mapper # mapping of InstanceState -> UOWTaskElement - self._objects = {} + self._objects = {} self.dependent_tasks = [] self.dependencies = set() self.cyclical_dependencies = set() def polymorphic_tasks(self): - """return an iterator of UOWTask objects corresponding to the inheritance sequence - of this UOWTask's mapper. - - e.g. if mapper B and mapper C inherit from mapper A, and mapper D inherits from B: - - mapperA -> mapperB -> mapperD - -> mapperC - - the inheritance sequence starting at mapper A is a depth-first traversal: - - [mapperA, mapperB, mapperD, mapperC] - - this method will therefore return - - [UOWTask(mapperA), UOWTask(mapperB), UOWTask(mapperD), UOWTask(mapperC)] - - The concept of "polymporphic iteration" is adapted into several property-based - iterators which return object instances, UOWTaskElements and UOWDependencyProcessors - in an order corresponding to this sequence of parent UOWTasks. This is used to issue - operations related to inheritance-chains of mappers in the proper order based on - dependencies between those mappers. - + """Return an iterator of UOWTask objects corresponding to the + inheritance sequence of this UOWTask's mapper. + + e.g. if mapper B and mapper C inherit from mapper A, and + mapper D inherits from B: + + mapperA -> mapperB -> mapperD + -> mapperC + + the inheritance sequence starting at mapper A is a depth-first + traversal: + + [mapperA, mapperB, mapperD, mapperC] + + this method will therefore return + + [UOWTask(mapperA), UOWTask(mapperB), UOWTask(mapperD), + UOWTask(mapperC)] + + The concept of "polymporphic iteration" is adapted into + several property-based iterators which return object + instances, UOWTaskElements and UOWDependencyProcessors in an + order corresponding to this sequence of parent UOWTasks. This + is used to issue operations related to inheritance-chains of + mappers in the proper order based on dependencies between + those mappers. + """ - for mapper in self.mapper.polymorphic_iterator(): t = self.base_task._inheriting_tasks.get(mapper, None) if t is not None: @@ -381,18 +385,18 @@ class UOWTask(object): """ return not self._objects and not self.dependencies - + def append(self, state, listonly=False, isdelete=False): if state not in self._objects: self._objects[state] = rec = UOWTaskElement(state) else: rec = self._objects[state] - + rec.update(listonly, isdelete) - + def append_postupdate(self, state, post_update_cols): - """issue a 'post update' UPDATE statement via this object's mapper immediately. - + """issue a 'post update' UPDATE statement via this object's mapper immediately. + this operation is used only with relations that specify the `post_update=True` flag. """ @@ -404,7 +408,7 @@ class UOWTask(object): def __contains__(self, state): """return True if the given object is contained within this UOWTask or inheriting tasks.""" - + for task in self.polymorphic_tasks(): if state in task._objects: return True @@ -413,7 +417,7 @@ class UOWTask(object): def is_deleted(self, state): """return True if the given object is marked as to be deleted within this UOWTask.""" - + try: return self._objects[state].isdelete except KeyError: @@ -422,49 +426,49 @@ class UOWTask(object): def _polymorphic_collection(callable): """return a property that will adapt the collection returned by the given callable into a polymorphic traversal.""" - + def collection(self): for task in self.polymorphic_tasks(): for rec in callable(task): yield rec return property(collection) - + def _elements(self): return self._objects.values() + elements = property(_elements) - polymorphic_elements = _polymorphic_collection(_elements) + @property def polymorphic_tosave_elements(self): return [rec for rec in self.polymorphic_elements if not rec.isdelete] - polymorphic_tosave_elements = property(polymorphic_tosave_elements) - + + @property def polymorphic_todelete_elements(self): return [rec for rec in self.polymorphic_elements if rec.isdelete] - polymorphic_todelete_elements = property(polymorphic_todelete_elements) + @property def polymorphic_tosave_objects(self): return [ rec.state for rec in self.polymorphic_elements if rec.state is not None and not rec.listonly and rec.isdelete is False ] - polymorphic_tosave_objects = property(polymorphic_tosave_objects) + @property def polymorphic_todelete_objects(self): return [ rec.state for rec in self.polymorphic_elements if rec.state is not None and not rec.listonly and rec.isdelete is True ] - polymorphic_todelete_objects = property(polymorphic_todelete_objects) + @_polymorphic_collection def polymorphic_dependencies(self): return self.dependencies - polymorphic_dependencies = _polymorphic_collection(polymorphic_dependencies) - + + @_polymorphic_collection def polymorphic_cyclical_dependencies(self): return self.cyclical_dependencies - polymorphic_cyclical_dependencies = _polymorphic_collection(polymorphic_cyclical_dependencies) - + def _sort_circular_dependencies(self, trans, cycles): """Create a hierarchical tree of *subtasks* which associate specific dependency actions with individual @@ -531,7 +535,7 @@ class UOWTask(object): (added, unchanged, deleted) = dep.get_object_dependencies(state, trans, passive=True) if not added and not unchanged and not deleted: continue - + # the task corresponding to saving/deleting of those dependent objects childtask = trans.get_task_by_mapper(processor.mapper) @@ -568,7 +572,7 @@ class UOWTask(object): get_dependency_task(state, dep).append(state, isdelete=isdelete) head = topological.sort_as_tree(tuples, allobjects) - + used_tasks = set() def make_task_tree(node, parenttask, nexttasks): (state, cycles, children) = node @@ -590,11 +594,11 @@ class UOWTask(object): return t t = UOWTask(self.uowtransaction, self.mapper) - + # stick the non-circular dependencies onto the new UOWTask for d in extradeplist: t.dependencies.add(d) - + if head is not None: make_task_tree(head, t, {}) @@ -610,7 +614,7 @@ class UOWTask(object): for dep in t2.dependencies: localtask.dependencies.add(dep) ret.insert(0, localtask) - + return ret def __repr__(self): @@ -618,9 +622,9 @@ class UOWTask(object): class UOWTaskElement(object): """Corresponds to a single InstanceState to be saved, deleted, - or otherwise marked as having dependencies. A collection of + or otherwise marked as having dependencies. A collection of UOWTaskElements are held by a UOWTask. - + """ def __init__(self, state): self.state = state @@ -645,7 +649,7 @@ class UOWTaskElement(object): each processor as marked as "processed" when complete, however changes to the state of this UOWTaskElement will reset - the list of completed processors, so that they + the list of completed processors, so that they execute again, until no new objects or state changes are brought in. """ @@ -663,7 +667,7 @@ class UOWDependencyProcessor(object): dependent data, such as filling in a foreign key on a child item from a new primary key, or deleting association rows before a delete. This object acts as a proxy to a DependencyProcessor. - + """ def __init__(self, processor, targettask): self.processor = processor @@ -671,10 +675,10 @@ class UOWDependencyProcessor(object): def __repr__(self): return "UOWDependencyProcessor(%s, %s)" % (str(self.processor), str(self.targettask)) - + def __str__(self): return repr(self) - + def __eq__(self, other): return other.processor is self.processor and other.targettask is self.targettask @@ -687,8 +691,8 @@ class UOWDependencyProcessor(object): This may locate additional objects which should be part of the transaction, such as those affected deletes, orphans to be deleted, etc. - - Once an object is preprocessed, its ``UOWTaskElement`` is marked as processed. If subsequent + + Once an object is preprocessed, its ``UOWTaskElement`` is marked as processed. If subsequent changes occur to the ``UOWTaskElement``, its processed flag is reset, and will require processing again. @@ -715,7 +719,7 @@ class UOWDependencyProcessor(object): def execute(self, trans, delete): """process all objects contained within this ``UOWDependencyProcessor``s target task.""" - + if not delete: self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_tosave_elements], trans, delete=False) else: @@ -725,25 +729,25 @@ class UOWDependencyProcessor(object): return trans.get_attribute_history(state, self.processor.key, passive=passive) def whose_dependent_on_who(self, state1, state2): - """establish which object is operationally dependent amongst a parent/child + """establish which object is operationally dependent amongst a parent/child using the semantics stated by the dependency processor. - + This method is used to establish a partial ordering (set of dependency tuples) when toplogically sorting on a per-instance basis. - + """ return self.processor.whose_dependent_on_who(state1, state2) def branch(self, task): """create a copy of this ``UOWDependencyProcessor`` against a new ``UOWTask`` object. - + this is used within the instance-level sorting operation when a single ``UOWTask`` is broken up into many individual ``UOWTask`` objects. - + """ return UOWDependencyProcessor(self.processor, task) - - + + class UOWExecutor(object): """Encapsulates the execution traversal of a UOWTransaction structure.""" diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 4915d930d6..b9abd0b79a 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -12,6 +12,7 @@ from sqlalchemy.sql import expression, util as sql_util, operators from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator, MapperProperty from sqlalchemy.orm import attributes, exc + all_cascades = frozenset(("delete", "delete-orphan", "all", "merge", "expunge", "save-update", "refresh-expire", "none")) @@ -209,9 +210,9 @@ class ExtensionCarrier(object): pass return _do + @staticmethod def _pass(*args, **kwargs): return EXT_CONTINUE - _pass = staticmethod(_pass) def __getattr__(self, key): """Delegate MapperExtension methods to bundled fronts.""" -- 2.47.3