From: Mike Bayer
Date: Sat, 8 Dec 2007 18:58:03 +0000 (+0000)
Subject: - flush() refactor merged from uow_nontree branch r3871-r3885
X-Git-Tag: rel_0_4_2~90
X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=8693d4b2876e9239cf48bbc42a7ffaa11c01b506;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git

- flush() refactor merged from uow_nontree branch r3871-r3885
- topological.py cleaned up, presents three public facing functions which
  return list/tuple based structures, without exposing any internals.  only
  the third function returns the "hierarchical" structure.  when results
  include "cycles" or "child" items, 2- or 3-tuples are used to represent
  results.
- unitofwork uses InstanceState almost exclusively now.  new and deleted
  lists are now dicts which ref the actual object, providing a strong ref
  for the duration that they're in those lists.  IdentitySet is only used
  for the public facing versions of "new" and "deleted".
- unitofwork topological sort no longer uses the "hierarchical" version of
  the sort for the base sort, only for the "per-object" secondary sort,
  where it still helps to group non-dependent operations together and
  provides expected insert order.  the default sort deals with UOWTasks in
  a straight list and is greatly simplified.  Tests all pass, but need to
  see if svilen's stuff still works; one block of code in
  _sort_cyclical_dependencies() seems to not be needed anywhere, but i
  definitely put it there for a reason at some point.  if not, hopefully we
  can derive more test coverage from that.
- the UOWEventHandler is only applied to object-storing attributes, not
  scalar (i.e. column-based) ones.  cuts out a ton of overhead when setting
  non-object based attributes.
- InstanceState is also used throughout the flush process, i.e.
  dependency.py, mapper.save_obj()/delete_obj(), sync.execute() all expect
  InstanceState objects in most cases now.
- mapper/property cascade_iterator() takes InstanceState as its argument,
  but still returns lists of object instances so that they are not
  dereferenced.
- a few tricks are needed when dealing with InstanceState, i.e. when
  loading a list of items that are possibly fresh from the DB, you *have*
  to get the actual objects into a strong-referencing datastructure, else
  they fall out of scope immediately.  dependency.py now caches the lists
  of dependent objects which it loads (i.e. history collections).
- AttributeHistory is gone, replaced by a function that returns a 3-tuple
  of (added, unchanged, deleted); see the usage sketch after the diff
  below.  these collections still reference the object instances directly
  for the strong-referencing reasons mentioned, but less IdentitySet logic
  is used to generate them.
---

diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 45e4c036f7..259909d479 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -320,6 +320,10 @@ class ExecutionContext(object):
       returns_rows
         True if the statement should return result rows
+
+      postfetch_cols
+        a list of Column objects for which a server-side default
+        or inline SQL expression value was fired off.  applies to inserts and updates.
 
     The Dialect should provide an ExecutionContext via the
     create_execution_context() method.  The `pre_exec` and `post_exec`
@@ -414,11 +418,6 @@ class ExecutionContext(object):
 
         raise NotImplementedError()
 
-    def postfetch_cols(self):
-        """return a list of Column objects for which a 'passive' server-side default
-        value was fired off.
applies to inserts and updates.""" - - raise NotImplementedError() class Compiled(object): """Represent a compiled SQL expression. @@ -1481,7 +1480,7 @@ class ResultProxy(object): See ExecutionContext for details. """ - return self.context.postfetch_cols() + return self.context.postfetch_cols def supports_sane_rowcount(self): """Return ``supports_sane_rowcount`` from the dialect. diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 0e50093ee6..38ea903e43 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -308,11 +308,8 @@ class DefaultExecutionContext(base.ExecutionContext): return self._last_updated_params def lastrow_has_defaults(self): - return hasattr(self, '_postfetch_cols') and len(self._postfetch_cols) + return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols) - def postfetch_cols(self): - return self._postfetch_cols - def set_input_sizes(self): """Given a cursor and ClauseParameters, call the appropriate style of ``setinputsizes()`` on the cursor, using DB-API types @@ -383,4 +380,4 @@ class DefaultExecutionContext(base.ExecutionContext): else: self._last_updated_params = compiled_parameters - self._postfetch_cols = self.compiled.postfetch + self.postfetch_cols = self.compiled.postfetch diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index bb7085402d..b73ec0e00f 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -181,8 +181,9 @@ class AttributeImpl(object): def get_history(self, state, passive=False): current = self.get(state, passive=passive) if current is PASSIVE_NORESULT: - return None - return AttributeHistory(self, state, current) + return (None, None, None) + else: + return _create_history(self, state, current) def set_callable(self, state, callable_, clear=False): """Set a callable function for this attribute on the given object. @@ -326,8 +327,8 @@ class ScalarAttributeImpl(AttributeImpl): def check_mutable_modified(self, state): if self.mutable_scalars: - h = self.get_history(state, passive=True) - if h is not None and h.is_modified(): + (added, unchanged, deleted) = self.get_history(state, passive=True) + if added or deleted: state.modified = True return True else: @@ -568,7 +569,6 @@ class CollectionAttributeImpl(AttributeImpl): collections.CollectionAdapter(self, state, user_data) return getattr(user_data, '_sa_adapter') - class GenericBackrefExtension(interfaces.AttributeExtension): """An extension which synchronizes a two-way relationship. 
@@ -612,14 +612,12 @@ class ClassState(object): class InstanceState(object): """tracks state information at the instance level.""" - __slots__ = 'class_', 'obj', 'dict', 'pending', 'committed_state', 'modified', 'trigger', 'callables', 'parents', 'instance_dict', '_strong_obj', 'expired_attributes' - def __init__(self, obj): self.class_ = obj.__class__ self.obj = weakref.ref(obj, self.__cleanup) self.dict = obj.__dict__ self.committed_state = {} - self.modified = False + self.modified = self.strong = False self.trigger = None self.callables = {} self.parents = {} @@ -681,7 +679,7 @@ class InstanceState(object): return False def __resurrect(self, instance_dict): - if self.is_modified(): + if self.strong or self.is_modified(): # store strong ref'ed version of the object; will revert # to weakref when changes are persisted obj = new_instance(self.class_, state=self) @@ -893,74 +891,46 @@ class WeakInstanceDict(UserDict.UserDict): class StrongInstanceDict(dict): def all_states(self): return [o._state for o in self.values()] - -class AttributeHistory(object): - """Calculate the *history* of a particular attribute on a - particular instance. - """ - def __init__(self, attr, state, current): - self.attr = attr +def _create_history(attr, state, current): + if state.committed_state: + original = state.committed_state.get(attr.key, NO_VALUE) + else: + original = NO_VALUE - # get the "original" value. if a lazy load was fired when we got - # the 'current' value, this "original" was also populated just - # now as well (therefore we have to get it second) - if state.committed_state: - original = state.committed_state.get(attr.key, NO_VALUE) + if hasattr(attr, 'get_collection'): + if original is NO_VALUE: + s = util.IdentitySet([]) else: - original = NO_VALUE - - if hasattr(attr, 'get_collection'): - self._current = current + s = util.IdentitySet(original) + + _added_items = [] + _unchanged_items = [] + _deleted_items = [] + if current: + collection = attr.get_collection(state, current) + for a in collection: + if a in s: + _unchanged_items.append(a) + else: + _added_items.append(a) + _deleted_items = list(s.difference(_unchanged_items)) - if original is NO_VALUE: - s = util.IdentitySet([]) - else: - s = util.IdentitySet(original) - - # FIXME: the tests have an assumption on the collection's ordering - self._added_items = util.OrderedIdentitySet() - self._unchanged_items = util.OrderedIdentitySet() - self._deleted_items = util.OrderedIdentitySet() - if current: - collection = attr.get_collection(state, current) - for a in collection: - if a in s: - self._unchanged_items.add(a) - else: - self._added_items.add(a) - for a in s: - if a not in self._unchanged_items: - self._deleted_items.add(a) + return (_added_items, _unchanged_items, _deleted_items) + else: + if attr.is_equal(current, original) is True: + _unchanged_items = [current] + _added_items = [] + _deleted_items = [] else: - self._current = [current] - if attr.is_equal(current, original) is True: - self._unchanged_items = [current] - self._added_items = [] - self._deleted_items = [] + _added_items = [current] + if original is not NO_VALUE and original is not None: + _deleted_items = [original] else: - self._added_items = [current] - if original is not NO_VALUE and original is not None: - self._deleted_items = [original] - else: - self._deleted_items = [] - self._unchanged_items = [] - - def __iter__(self): - return iter(self._current) - - def is_modified(self): - return len(self._deleted_items) > 0 or len(self._added_items) > 0 - - def 
added_items(self): - return list(self._added_items) - - def unchanged_items(self): - return list(self._unchanged_items) - - def deleted_items(self): - return list(self._deleted_items) - + _deleted_items = [] + _unchanged_items = [] + return (_added_items, _unchanged_items, _deleted_items) + class PendingCollection(object): """stores items appended and removed from a collection that has not been loaded yet. @@ -987,30 +957,25 @@ def _managed_attributes(class_): return chain(*[cl._class_state.attrs.values() for cl in class_.__mro__[:-1] if hasattr(cl, '_class_state')]) -def is_modified(instance): - return instance._state.is_modified() - -def get_history(instance, key, **kwargs): - return getattr(instance.__class__, key).impl.get_history(instance._state, **kwargs) - -def get_as_list(instance, key, passive=False): - """Return an attribute of the given name from the given instance. +def get_history(state, key, **kwargs): + return getattr(state.class_, key).impl.get_history(state, **kwargs) +get_state_history = get_history - If the attribute is a scalar, return it as a single-item list, - otherwise return a collection based attribute. - - If the attribute's value is to be produced by an unexecuted - callable, the callable will only be executed if the given - `passive` flag is False. +def get_as_list(state, key, passive=False): + """return an InstanceState attribute as a list, + regardless of it being a scalar or collection-based + attribute. + + returns None if passive=True and the getter returns + PASSIVE_NORESULT. """ - - attr = getattr(instance.__class__, key).impl - state = instance._state + + attr = getattr(state.class_, key).impl x = attr.get(state, passive=passive) if x is PASSIVE_NORESULT: - return [] + return None elif hasattr(attr, 'get_collection'): - return list(attr.get_collection(state, x)) + return attr.get_collection(state, x) elif isinstance(x, list): return x else: diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index 9220c5743b..ae499ce1ec 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -51,12 +51,12 @@ class DependencyProcessor(object): return getattr(self.parent.class_, self.key) - def hasparent(self, obj): + def hasparent(self, state): """return True if the given object instance has a parent, according to the ``InstrumentedAttribute`` handled by this ``DependencyProcessor``.""" # TODO: use correct API for this - return self._get_instrumented_attribute().impl.hasparent(obj._state) + return self._get_instrumented_attribute().impl.hasparent(state) def register_dependencies(self, uowcommit): """Tell a ``UOWTransaction`` what mappers are dependent on @@ -72,21 +72,18 @@ class DependencyProcessor(object): raise NotImplementedError() - def whose_dependent_on_who(self, obj1, obj2): + def whose_dependent_on_who(self, state1, state2): """Given an object pair assuming `obj2` is a child of `obj1`, return a tuple with the dependent object second, or None if - they are equal. - - Used by objectstore's object-level topological sort (i.e. cyclical - table dependency). + there is no dependency. 
""" - if obj1 is obj2: + if state1 is state2: return None elif self.direction == ONETOMANY: - return (obj1, obj2) + return (state1, state2) else: - return (obj2, obj1) + return (state2, state1) def process_dependencies(self, task, deplist, uowcommit, delete = False): """This method is called during a flush operation to @@ -108,13 +105,13 @@ class DependencyProcessor(object): raise NotImplementedError() - def _verify_canload(self, child): + def _verify_canload(self, state): if not self.enable_typechecks: return - if child is not None and not self.mapper._canload(child): - raise exceptions.FlushError("Attempting to flush an item of type %s on collection '%s', which is handled by mapper '%s' and does not load items of that type. Did you mean to use a polymorphic mapper for this relationship ? Set 'enable_typechecks=False' on the relation() to disable this exception. Mismatched typeloading may cause bi-directional relationships (backrefs) to not function properly." % (child.__class__, self.prop, self.mapper)) + if state is not None and not self.mapper._canload(state): + raise exceptions.FlushError("Attempting to flush an item of type %s on collection '%s', which is handled by mapper '%s' and does not load items of that type. Did you mean to use a polymorphic mapper for this relationship ? Set 'enable_typechecks=False' on the relation() to disable this exception. Mismatched typeloading may cause bi-directional relationships (backrefs) to not function properly." % (state.class_, self.prop, self.mapper)) - def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit): + def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): """Called during a flush to synchronize primary key identifier values between a parent/child object, as well as to an associationrow in the case of many-to-many. @@ -123,13 +120,10 @@ class DependencyProcessor(object): raise NotImplementedError() def _compile_synchronizers(self): - """Assemble a list of *synchronization rules*, which are - instructions on how to populate the objects on each side of a - relationship. This is done when a ``DependencyProcessor`` is - first initialized. - - The list of rules is used within commits by the ``_synchronize()`` - method when dependent objects are processed. + """Assemble a list of *synchronization rules*. + + These are fired to populate attributes from one side + of a relation to another. """ self.syncrules = sync.ClauseSynchronizer(self.parent, self.mapper, self.direction) @@ -139,15 +133,28 @@ class DependencyProcessor(object): else: self.syncrules.compile(self.prop.primaryjoin, foreign_keys=self.foreign_keys) - def get_object_dependencies(self, obj, uowcommit, passive = True): - """Return the list of objects that are dependent on the given - object, as according to the relationship this dependency - processor represents. 
- """ - - return attributes.get_history(obj, self.key, passive = passive) + def get_object_dependencies(self, state, uowcommit, passive = True): + key = ("dependencies", state, self.key, passive) + + # cache the objects, not the states; the strong reference here + # prevents newly loaded objects from being dereferenced during the + # flush process + if key in uowcommit.attributes: + (added, unchanged, deleted) = uowcommit.attributes[key] + else: + (added, unchanged, deleted) = attributes.get_history(state, self.key, passive = passive) + uowcommit.attributes[key] = (added, unchanged, deleted) + + if added is None: + return (added, unchanged, deleted) + else: + return ( + [getattr(c, '_state', None) for c in added], + [getattr(c, '_state', None) for c in unchanged], + [getattr(c, '_state', None) for c in deleted], + ) - def _conditional_post_update(self, obj, uowcommit, related): + def _conditional_post_update(self, state, uowcommit, related): """Execute a post_update call. For relations that contain the post_update flag, an additional @@ -161,10 +168,10 @@ class DependencyProcessor(object): given related object list contains ``INSERT``s or ``DELETE``s. """ - if obj is not None and self.post_update: + if state is not None and self.post_update: for x in related: if x is not None: - uowcommit.register_object(obj, postupdate=True, post_update_cols=self.syncrules.dest_columns()) + uowcommit.register_object(state, postupdate=True, post_update_cols=self.syncrules.dest_columns()) break def __str__(self): @@ -190,27 +197,28 @@ class OneToManyDP(DependencyProcessor): # this phase can be called safely for any cascade but is unnecessary if delete cascade # is on. if (not self.cascade.delete or self.post_update) and not self.passive_deletes=='all': - for obj in deplist: - childlist = self.get_object_dependencies(obj, uowcommit, passive=self.passive_deletes) - if childlist is not None: - for child in childlist.deleted_items(): + for state in deplist: + (added, unchanged, deleted) = self.get_object_dependencies(state, uowcommit, passive=self.passive_deletes) + if unchanged or deleted: + for child in deleted: if child is not None and self.hasparent(child) is False: - self._synchronize(obj, child, None, True, uowcommit) - self._conditional_post_update(child, uowcommit, [obj]) - for child in childlist.unchanged_items(): + self._synchronize(state, child, None, True, uowcommit) + self._conditional_post_update(child, uowcommit, [state]) + for child in unchanged: if child is not None: - self._synchronize(obj, child, None, True, uowcommit) - self._conditional_post_update(child, uowcommit, [obj]) + self._synchronize(state, child, None, True, uowcommit) + self._conditional_post_update(child, uowcommit, [state]) else: - for obj in deplist: - childlist = self.get_object_dependencies(obj, uowcommit, passive=True) - if childlist is not None: - for child in childlist.added_items(): - self._synchronize(obj, child, None, False, uowcommit) - self._conditional_post_update(child, uowcommit, [obj]) - for child in childlist.deleted_items(): + for state in deplist: + (added, unchanged, deleted) = self.get_object_dependencies(state, uowcommit, passive=True) + if added or deleted: + for child in added: + self._synchronize(state, child, None, False, uowcommit) + if child is not None: + self._conditional_post_update(child, uowcommit, [state]) + for child in deleted: if not self.cascade.delete_orphan and not self.hasparent(child): - self._synchronize(obj, child, None, True, uowcommit) + self._synchronize(state, child, None, True, 
uowcommit) def preprocess_dependencies(self, task, deplist, uowcommit, delete = False): #print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " preprocess_dep isdelete " + repr(delete) + " direction " + repr(self.direction) @@ -219,37 +227,39 @@ class OneToManyDP(DependencyProcessor): # head object is being deleted, and we manage its list of child objects # the child objects have to have their foreign key to the parent set to NULL if not self.post_update and not self.cascade.delete and not self.passive_deletes=='all': - for obj in deplist: - childlist = self.get_object_dependencies(obj, uowcommit, passive=self.passive_deletes) - if childlist is not None: - for child in childlist.deleted_items(): + for state in deplist: + (added, unchanged, deleted) = self.get_object_dependencies(state, uowcommit, passive=self.passive_deletes) + if unchanged or deleted: + for child in deleted: if child is not None and self.hasparent(child) is False: uowcommit.register_object(child) - for child in childlist.unchanged_items(): + for child in unchanged: if child is not None: uowcommit.register_object(child) else: - for obj in deplist: - childlist = self.get_object_dependencies(obj, uowcommit, passive=True) - if childlist is not None: - for child in childlist.added_items(): + for state in deplist: + (added, unchanged, deleted) = self.get_object_dependencies(state, uowcommit, passive=True) + if added or deleted: + for child in added: if child is not None: uowcommit.register_object(child) - for child in childlist.deleted_items(): + for child in deleted: if not self.cascade.delete_orphan: uowcommit.register_object(child, isdelete=False) elif self.hasparent(child) is False: uowcommit.register_object(child, isdelete=True) for c, m in self.mapper.cascade_iterator('delete', child): - uowcommit.register_object(c, isdelete=True) + uowcommit.register_object(c._state, isdelete=True) - def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit): - source = obj + def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): + if child is not None: + child = getattr(child, '_state', child) + source = state dest = child - if dest is None or (not self.post_update and uowcommit.is_deleted(dest)): + if dest is None or (not self.post_update and uowcommit.state_is_deleted(dest)): return self._verify_canload(child) - self.syncrules.execute(source, dest, obj, child, clearkeys) + self.syncrules.execute(source, dest, source, child, clearkeys) class ManyToOneDP(DependencyProcessor): def register_dependencies(self, uowcommit): @@ -269,18 +279,18 @@ class ManyToOneDP(DependencyProcessor): if self.post_update and not self.cascade.delete_orphan and not self.passive_deletes=='all': # post_update means we have to update our row to not reference the child object # before we can DELETE the row - for obj in deplist: - self._synchronize(obj, None, None, True, uowcommit) - childlist = self.get_object_dependencies(obj, uowcommit, passive=self.passive_deletes) - if childlist is not None: - self._conditional_post_update(obj, uowcommit, childlist.deleted_items() + childlist.unchanged_items() + childlist.added_items()) + for state in deplist: + self._synchronize(state, None, None, True, uowcommit) + (added, unchanged, deleted) = self.get_object_dependencies(state, uowcommit, passive=self.passive_deletes) + if added or unchanged or deleted: + self._conditional_post_update(state, uowcommit, deleted + unchanged + added) else: - for obj in deplist: - childlist = self.get_object_dependencies(obj, 
uowcommit, passive=True) - if childlist is not None: - for child in childlist.added_items(): - self._synchronize(obj, child, None, False, uowcommit) - self._conditional_post_update(obj, uowcommit, childlist.deleted_items() + childlist.unchanged_items() + childlist.added_items()) + for state in deplist: + (added, unchanged, deleted) = self.get_object_dependencies(state, uowcommit, passive=True) + if added or deleted or unchanged: + for child in added: + self._synchronize(state, child, None, False, uowcommit) + self._conditional_post_update(state, uowcommit, deleted + unchanged + added) def preprocess_dependencies(self, task, deplist, uowcommit, delete = False): #print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " PRE process_dep isdelete " + repr(delete) + " direction " + repr(self.direction) @@ -288,33 +298,33 @@ class ManyToOneDP(DependencyProcessor): return if delete: if self.cascade.delete: - for obj in deplist: - childlist = self.get_object_dependencies(obj, uowcommit, passive=self.passive_deletes) - if childlist is not None: - for child in childlist.deleted_items() + childlist.unchanged_items(): + for state in deplist: + (added, unchanged, deleted) = self.get_object_dependencies(state, uowcommit, passive=self.passive_deletes) + if deleted or unchanged: + for child in deleted + unchanged: if child is not None and self.hasparent(child) is False: uowcommit.register_object(child, isdelete=True) for c, m in self.mapper.cascade_iterator('delete', child): - uowcommit.register_object(c, isdelete=True) + uowcommit.register_object(c._state, isdelete=True) else: - for obj in deplist: - uowcommit.register_object(obj) + for state in deplist: + uowcommit.register_object(state) if self.cascade.delete_orphan: - childlist = self.get_object_dependencies(obj, uowcommit, passive=self.passive_deletes) - if childlist is not None: - for child in childlist.deleted_items(): + (added, unchanged, deleted) = self.get_object_dependencies(state, uowcommit, passive=self.passive_deletes) + if deleted: + for child in deleted: if self.hasparent(child) is False: uowcommit.register_object(child, isdelete=True) for c, m in self.mapper.cascade_iterator('delete', child): - uowcommit.register_object(c, isdelete=True) + uowcommit.register_object(c._state, isdelete=True) - def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit): + def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): source = child - dest = obj - if dest is None or (not self.post_update and uowcommit.is_deleted(dest)): + dest = state + if dest is None or (not self.post_update and uowcommit.state_is_deleted(dest)): return self._verify_canload(child) - self.syncrules.execute(source, dest, obj, child, clearkeys) + self.syncrules.execute(source, dest, dest, child, clearkeys) class ManyToManyDP(DependencyProcessor): def register_dependencies(self, uowcommit): @@ -341,34 +351,34 @@ class ManyToManyDP(DependencyProcessor): reverse_dep = None if delete: - for obj in deplist: - childlist = self.get_object_dependencies(obj, uowcommit, passive=self.passive_deletes) - if childlist is not None: - for child in childlist.deleted_items() + childlist.unchanged_items(): - if child is None or (reverse_dep and (reverse_dep, "manytomany", id(child), id(obj)) in uowcommit.attributes): + for state in deplist: + (added, unchanged, deleted) = self.get_object_dependencies(state, uowcommit, passive=self.passive_deletes) + if deleted or unchanged: + for child in deleted + unchanged: + if child is None or (reverse_dep 
and (reverse_dep, "manytomany", child, state) in uowcommit.attributes): continue associationrow = {} - self._synchronize(obj, child, associationrow, False, uowcommit) + self._synchronize(state, child, associationrow, False, uowcommit) secondary_delete.append(associationrow) - uowcommit.attributes[(self, "manytomany", id(obj), id(child))] = True + uowcommit.attributes[(self, "manytomany", state, child)] = True else: - for obj in deplist: - childlist = self.get_object_dependencies(obj, uowcommit) - if childlist is None: continue - for child in childlist.added_items(): - if child is None or (reverse_dep and (reverse_dep, "manytomany", id(child), id(obj)) in uowcommit.attributes): - continue - associationrow = {} - self._synchronize(obj, child, associationrow, False, uowcommit) - uowcommit.attributes[(self, "manytomany", id(obj), id(child))] = True - secondary_insert.append(associationrow) - for child in childlist.deleted_items(): - if child is None or (reverse_dep and (reverse_dep, "manytomany", id(child), id(obj)) in uowcommit.attributes): - continue - associationrow = {} - self._synchronize(obj, child, associationrow, False, uowcommit) - uowcommit.attributes[(self, "manytomany", id(obj), id(child))] = True - secondary_delete.append(associationrow) + for state in deplist: + (added, unchanged, deleted) = self.get_object_dependencies(state, uowcommit) + if added or deleted: + for child in added: + if child is None or (reverse_dep and (reverse_dep, "manytomany", child, state) in uowcommit.attributes): + continue + associationrow = {} + self._synchronize(state, child, associationrow, False, uowcommit) + uowcommit.attributes[(self, "manytomany", state, child)] = True + secondary_insert.append(associationrow) + for child in deleted: + if child is None or (reverse_dep and (reverse_dep, "manytomany", child, state) in uowcommit.attributes): + continue + associationrow = {} + self._synchronize(state, child, associationrow, False, uowcommit) + uowcommit.attributes[(self, "manytomany", state, child)] = True + secondary_delete.append(associationrow) if secondary_delete: secondary_delete.sort() @@ -385,22 +395,22 @@ class ManyToManyDP(DependencyProcessor): def preprocess_dependencies(self, task, deplist, uowcommit, delete = False): #print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " preprocess_dep isdelete " + repr(delete) + " direction " + repr(self.direction) if not delete: - for obj in deplist: - childlist = self.get_object_dependencies(obj, uowcommit, passive=True) - if childlist is not None: - for child in childlist.deleted_items(): + for state in deplist: + (added, unchanged, deleted) = self.get_object_dependencies(state, uowcommit, passive=True) + if deleted: + for child in deleted: if self.cascade.delete_orphan and self.hasparent(child) is False: uowcommit.register_object(child, isdelete=True) for c, m in self.mapper.cascade_iterator('delete', child): - uowcommit.register_object(c, isdelete=True) + uowcommit.register_object(c._state, isdelete=True) - def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit): + def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): dest = associationrow source = None if dest is None: return self._verify_canload(child) - self.syncrules.execute(source, dest, obj, child, clearkeys) + self.syncrules.execute(source, dest, state, child, clearkeys) class AssociationDP(OneToManyDP): def __init__(self, *args, **kwargs): @@ -413,7 +423,7 @@ class MapperStub(object): many-to-many join, when performing a 
``flush()``. The ``Task`` objects in the objectstore module treat it just like - any other ``Mapper``, but in fact it only serves as a *dependency* + any other ``Mapper``, but in fact it only serves as a dependency placeholder for the many-to-many update task. """ diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index e55703c420..af1be28bc6 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -13,7 +13,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl): def get(self, state, passive=False): if passive: - return self.get_history(state, passive=True).added_items() + return self._get_collection(state, passive=True).added_items else: return AppenderQuery(self, state) @@ -23,7 +23,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl): state.dict[self.key] = CollectionHistory(self, state) def get_collection(self, state, user_data=None): - return self.get_history(state, passive=True)._added_items + return self._get_collection(state, passive=True).added_items def set(self, state, value, initiator): if initiator is self: @@ -38,6 +38,10 @@ class DynamicAttributeImpl(attributes.AttributeImpl): raise NotImplementedError() def get_history(self, state, passive=False): + c = self._get_collection(state, passive) + return (c.added_items, c.unchanged_items, c.deleted_items) + + def _get_collection(self, state, passive=False): try: c = state.dict[self.key] except KeyError: @@ -47,15 +51,15 @@ class DynamicAttributeImpl(attributes.AttributeImpl): return CollectionHistory(self, state, apply_to=c) else: return c - + def append(self, state, value, initiator, passive=False): if initiator is not self: - self.get_history(state, passive=True)._added_items.append(value) + self._get_collection(state, passive=True).added_items.append(value) self.fire_append_event(state, value, initiator) def remove(self, state, value, initiator, passive=False): if initiator is not self: - self.get_history(state, passive=True)._deleted_items.append(value) + self._get_collection(state, passive=True).deleted_items.append(value) self.fire_remove_event(state, value, initiator) @@ -82,21 +86,21 @@ class AppenderQuery(Query): def __iter__(self): sess = self.__session() if sess is None: - return iter(self.attr.get_history(self.state, passive=True)._added_items) + return iter(self.attr._get_collection(self.state, passive=True).added_items) else: return iter(self._clone(sess)) def __getitem__(self, index): sess = self.__session() if sess is None: - return self.attr.get_history(self.state, passive=True)._added_items.__getitem__(index) + return self.attr._get_collection(self.state, passive=True).added_items.__getitem__(index) else: return self._clone(sess).__getitem__(index) def count(self): sess = self.__session() if sess is None: - return len(self.attr.get_history(self.state, passive=True)._added_items) + return len(self.attr._get_collection(self.state, passive=True).added_items) else: return self._clone(sess).count() @@ -121,7 +125,7 @@ class AppenderQuery(Query): oldlist = list(self) else: oldlist = [] - self.attr.get_history(self.state, passive=True).replace(oldlist, collection) + self.attr._get_collection(self.state, passive=True).replace(oldlist, collection) return oldlist def append(self, item): @@ -131,35 +135,23 @@ class AppenderQuery(Query): self.attr.remove(self.state, item, None) -class CollectionHistory(attributes.AttributeHistory): +class CollectionHistory(object): """Overrides AttributeHistory to receive append/remove events directly.""" def __init__(self, attr, 
state, apply_to=None): if apply_to: - deleted = util.IdentitySet(apply_to._deleted_items) - added = apply_to._added_items + deleted = util.IdentitySet(apply_to.deleted_items) + added = apply_to.added_items coll = AppenderQuery(attr, state).autoflush(False) - self._unchanged_items = [o for o in util.IdentitySet(coll) if o not in deleted] - self._added_items = apply_to._added_items - self._deleted_items = apply_to._deleted_items + self.unchanged_items = [o for o in util.IdentitySet(coll) if o not in deleted] + self.added_items = apply_to.added_items + self.deleted_items = apply_to.deleted_items else: - self._deleted_items = [] - self._added_items = [] - self._unchanged_items = [] + self.deleted_items = [] + self.added_items = [] + self.unchanged_items = [] def replace(self, olditems, newitems): - self._added_items = newitems - self._deleted_items = olditems + self.added_items = newitems + self.deleted_items = olditems - def is_modified(self): - return len(self._deleted_items) > 0 or len(self._added_items) > 0 - - def added_items(self): - return self._added_items - - def unchanged_items(self): - return self._unchanged_items - - def deleted_items(self): - return self._deleted_items - diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index f5d2e65b4a..ae4711beb8 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -898,37 +898,45 @@ class Mapper(object): """ return self.identity_key_from_primary_key(self.primary_key_from_instance(instance)) + def _identity_key_from_state(self, state): + return self.identity_key_from_primary_key(self._primary_key_from_state(state)) + def primary_key_from_instance(self, instance): """Return the list of primary key values for the given instance. """ - return [self._get_attr_by_column(instance, column) for column in self.primary_key] + return [self._get_state_attr_by_column(instance._state, column) for column in self.primary_key] + + def _primary_key_from_state(self, state): + return [self._get_state_attr_by_column(state, column) for column in self.primary_key] - def _canload(self, instance): - """return true if this mapper is capable of loading the given instance""" + def _canload(self, state): if self.polymorphic_on is not None: - return isinstance(instance, self.class_) + return issubclass(state.class_, self.class_) else: - return instance.__class__ is self.class_ - - def _get_attr_by_column(self, obj, column): - """Return an instance attribute using a Column as the key.""" + return state.class_ is self.class_ + + def _get_state_attr_by_column(self, state, column): try: - return self._columntoproperty[column].getattr(obj, column) + return self._columntoproperty[column].getattr(state, column) except KeyError: prop = self.__props.get(column.key, None) if prop: raise exceptions.InvalidRequestError("Column '%s.%s' is not available, due to conflicting property '%s':%s" % (column.table.name, column.name, column.key, repr(prop))) else: raise exceptions.InvalidRequestError("No column %s.%s is configured on mapper %s..." 
% (column.table.name, column.name, str(self))) + + def _set_state_attr_by_column(self, state, column, value): + return self._columntoproperty[column].setattr(state, value, column) + + def _get_attr_by_column(self, obj, column): + return self._get_state_attr_by_column(obj._state, column) def _set_attr_by_column(self, obj, column, value): - """Set the value of an instance attribute using a Column as the key.""" + self._set_state_attr_by_column(obj._state, column, value) - self._columntoproperty[column].setattr(obj, value, column) - - def save_obj(self, objects, uowtransaction, postupdate=False, post_update_cols=None, single=False): + def save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False): """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. This is called within the context of a UOWTransaction during a @@ -947,44 +955,44 @@ class Mapper(object): # if batch=false, call save_obj separately for each object if not single and not self.batch: - for obj in objects: - self.save_obj([obj], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True) + for state in states: + self.save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True) return if 'connection_callable' in uowtransaction.mapper_flush_opts: connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] - tups = [(obj, connection_callable(self, obj)) for obj in objects] + tups = [(state, connection_callable(self, state.obj())) for state in states] else: connection = uowtransaction.transaction.connection(self) - tups = [(obj, connection) for obj in objects] + tups = [(state, connection) for state in states] if not postupdate: - for obj, connection in tups: - if not has_identity(obj): - for mapper in object_mapper(obj).iterate_to_root(): + for state, connection in tups: + if not _state_has_identity(state): + for mapper in _state_mapper(state).iterate_to_root(): if 'before_insert' in mapper.extension.methods: - mapper.extension.before_insert(mapper, connection, obj) + mapper.extension.before_insert(mapper, connection, state.obj()) else: - for mapper in object_mapper(obj).iterate_to_root(): + for mapper in _state_mapper(state).iterate_to_root(): if 'before_update' in mapper.extension.methods: - mapper.extension.before_update(mapper, connection, obj) + mapper.extension.before_update(mapper, connection, state.obj()) - for obj, connection in tups: + for state, connection in tups: # detect if we have a "pending" instance (i.e. has no instance_key attached to it), # and another instance with the same identity key already exists as persistent. convert to an # UPDATE if so. 
- mapper = object_mapper(obj) - instance_key = mapper.identity_key_from_instance(obj) - if not postupdate and not has_identity(obj) and instance_key in uowtransaction.uow.identity_map: + mapper = _state_mapper(state) + instance_key = mapper._identity_key_from_state(state) + if not postupdate and not _state_has_identity(state) and instance_key in uowtransaction.uow.identity_map: existing = uowtransaction.uow.identity_map[instance_key] if not uowtransaction.is_deleted(existing): - raise exceptions.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (mapperutil.instance_str(obj), str(instance_key), mapperutil.instance_str(existing))) + raise exceptions.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (mapperutil.state_str(state), str(instance_key), mapperutil.instance_str(existing))) if self.__should_log_debug: - self.__log_debug("detected row switch for identity %s. will update %s, remove %s from transaction" % (instance_key, mapperutil.instance_str(obj), mapperutil.instance_str(existing))) + self.__log_debug("detected row switch for identity %s. will update %s, remove %s from transaction" % (instance_key, mapperutil.state_str(state), mapperutil.instance_str(existing))) uowtransaction.set_row_switch(existing) - if has_identity(obj): - if obj._instance_key != instance_key: - raise exceptions.FlushError("Can't change the identity of instance %s in session (existing identity: %s; new identity: %s)" % (mapperutil.instance_str(obj), obj._instance_key, instance_key)) + if _state_has_identity(state): + if state.dict['_instance_key'] != instance_key: + raise exceptions.FlushError("Can't change the identity of instance %s in session (existing identity: %s; new identity: %s)" % (mapperutil.state_str(state), state.dict['_instance_key'], instance_key)) inserted_objects = util.Set() updated_objects = util.Set() @@ -999,17 +1007,17 @@ class Mapper(object): insert = [] update = [] - for obj, connection in tups: - mapper = object_mapper(obj) + for state, connection in tups: + mapper = _state_mapper(state) if table not in mapper._pks_by_table: continue pks = mapper._pks_by_table[table] - instance_key = mapper.identity_key_from_instance(obj) + instance_key = mapper._identity_key_from_state(state) if self.__should_log_debug: - self.__log_debug("save_obj() table '%s' instance %s identity %s" % (table.name, mapperutil.instance_str(obj), str(instance_key))) + self.__log_debug("save_obj() table '%s' instance %s identity %s" % (table.name, mapperutil.state_str(state), str(instance_key))) - isinsert = not instance_key in uowtransaction.uow.identity_map and not postupdate and not has_identity(obj) + isinsert = not instance_key in uowtransaction.uow.identity_map and not postupdate and not _state_has_identity(state) params = {} value_params = {} hasdata = False @@ -1019,7 +1027,7 @@ class Mapper(object): if col is mapper.version_id_col: params[col.key] = 1 elif col in pks: - value = mapper._get_attr_by_column(obj, col) + value = mapper._get_state_attr_by_column(state, col) if value is not None: params[col.key] = value elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col): @@ -1029,41 +1037,39 @@ class Mapper(object): if col.default is None or value is not None: params[col.key] = value else: - value = mapper._get_attr_by_column(obj, col) + value = mapper._get_state_attr_by_column(state, col) if col.default is None or value is not None: if isinstance(value, sql.ClauseElement): value_params[col] = value else: 
params[col.key] = value - insert.append((obj, params, mapper, connection, value_params)) + insert.append((state, params, mapper, connection, value_params)) else: for col in mapper._cols_by_table[table]: if col is mapper.version_id_col: - params[col._label] = mapper._get_attr_by_column(obj, col) + params[col._label] = mapper._get_state_attr_by_column(state, col) params[col.key] = params[col._label] + 1 for prop in mapper._columntoproperty.values(): - history = attributes.get_history(obj, prop.key, passive=True) - if history and history.added_items(): + (added, unchanged, deleted) = attributes.get_history(state, prop.key, passive=True) + if added: hasdata = True elif col in pks: - params[col._label] = mapper._get_attr_by_column(obj, col) + params[col._label] = mapper._get_state_attr_by_column(state, col) elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col): pass else: if post_update_cols is not None and col not in post_update_cols: continue prop = mapper._columntoproperty[col] - history = attributes.get_history(obj, prop.key, passive=True) - if history: - a = history.added_items() - if a: - if isinstance(a[0], sql.ClauseElement): - value_params[col] = a[0] - else: - params[col.key] = prop.get_col_value(col, a[0]) - hasdata = True + (added, unchanged, deleted) = attributes.get_history(state, prop.key, passive=True) + if added: + if isinstance(added[0], sql.ClauseElement): + value_params[col] = added[0] + else: + params[col.key] = prop.get_col_value(col, added[0]) + hasdata = True if hasdata: - update.append((obj, params, mapper, connection, value_params)) + update.append((state, params, mapper, connection, value_params)) if update: mapper = table_to_mapper[table] @@ -1084,12 +1090,12 @@ class Mapper(object): return 0 update.sort(comparator) for rec in update: - (obj, params, mapper, connection, value_params) = rec + (state, params, mapper, connection, value_params) = rec c = connection.execute(statement.values(value_params), params) - mapper._postfetch(connection, table, obj, c, c.last_updated_params(), value_params) + mapper._postfetch(connection, table, state, c, c.last_updated_params(), value_params) # testlib.pragma exempt:__hash__ - updated_objects.add((id(obj), obj, connection)) + updated_objects.add((state, connection)) rows += c.rowcount if c.supports_sane_rowcount() and rows != len(update): @@ -1098,49 +1104,49 @@ class Mapper(object): if insert: statement = table.insert() def comparator(a, b): - return cmp(a[0]._sa_insert_order, b[0]._sa_insert_order) + return cmp(a[0].insert_order, b[0].insert_order) insert.sort(comparator) for rec in insert: - (obj, params, mapper, connection, value_params) = rec + (state, params, mapper, connection, value_params) = rec c = connection.execute(statement.values(value_params), params) primary_key = c.last_inserted_ids() if primary_key is not None: i = 0 for col in mapper._pks_by_table[table]: - if mapper._get_attr_by_column(obj, col) is None and len(primary_key) > i: - mapper._set_attr_by_column(obj, col, primary_key[i]) + if mapper._get_state_attr_by_column(state, col) is None and len(primary_key) > i: + mapper._set_state_attr_by_column(state, col, primary_key[i]) i+=1 - mapper._postfetch(connection, table, obj, c, c.last_inserted_params(), value_params) + mapper._postfetch(connection, table, state, c, c.last_inserted_params(), value_params) # synchronize newly inserted ids from one table to the next # TODO: this fires off more than needed, try to organize syncrules # per table for m in 
util.reversed(list(mapper.iterate_to_root())): if m._synchronizer is not None: - m._synchronizer.execute(obj, obj) + m._synchronizer.execute(state, state) # testlib.pragma exempt:__hash__ - inserted_objects.add((id(obj), obj, connection)) + inserted_objects.add((state, connection)) if not postupdate: - for id_, obj, connection in inserted_objects: - for mapper in object_mapper(obj).iterate_to_root(): + for state, connection in inserted_objects: + for mapper in _state_mapper(state).iterate_to_root(): if 'after_insert' in mapper.extension.methods: - mapper.extension.after_insert(mapper, connection, obj) - for id_, obj, connection in updated_objects: - for mapper in object_mapper(obj).iterate_to_root(): + mapper.extension.after_insert(mapper, connection, state.obj()) + for state, connection in updated_objects: + for mapper in _state_mapper(state).iterate_to_root(): if 'after_update' in mapper.extension.methods: - mapper.extension.after_update(mapper, connection, obj) + mapper.extension.after_update(mapper, connection, state) - def _postfetch(self, connection, table, obj, resultproxy, params, value_params): + def _postfetch(self, connection, table, state, resultproxy, params, value_params): """After an ``INSERT`` or ``UPDATE``, assemble newly generated values on an instance. For columns which are marked as being generated on the database side, set up a group-based "deferred" loader which will populate those attributes in one query when next accessed. """ - postfetch_cols = resultproxy.postfetch_cols().union(util.Set(value_params.keys())) + postfetch_cols = util.Set(resultproxy.postfetch_cols()).union(util.Set(value_params.keys())) deferred_props = [] for c in self._cols_by_table[table]: @@ -1150,13 +1156,13 @@ class Mapper(object): continue if c.primary_key or not c.key in params: continue - if self._get_attr_by_column(obj, c) != params[c.key]: - self._set_attr_by_column(obj, c, params[c.key]) + if self._get_state_attr_by_column(state, c) != params[c.key]: + self._set_state_attr_by_column(state, c, params[c.key]) if deferred_props: - expire_instance(obj, deferred_props) + _expire_state(state, deferred_props) - def delete_obj(self, objects, uowtransaction): + def delete_obj(self, states, uowtransaction): """Issue ``DELETE`` statements for a list of objects. 
This is called within the context of a UOWTransaction during a @@ -1168,15 +1174,15 @@ class Mapper(object): if 'connection_callable' in uowtransaction.mapper_flush_opts: connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] - tups = [(obj, connection_callable(self, obj)) for obj in objects] + tups = [(state, connection_callable(self, state.obj())) for state in states] else: connection = uowtransaction.transaction.connection(self) - tups = [(obj, connection) for obj in objects] + tups = [(state, connection) for state in states] - for (obj, connection) in tups: - for mapper in object_mapper(obj).iterate_to_root(): + for (state, connection) in tups: + for mapper in _state_mapper(state).iterate_to_root(): if 'before_delete' in mapper.extension.methods: - mapper.extension.before_delete(mapper, connection, obj) + mapper.extension.before_delete(mapper, connection, state.obj()) deleted_objects = util.Set() table_to_mapper = {} @@ -1186,22 +1192,22 @@ class Mapper(object): for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=True): delete = {} - for (obj, connection) in tups: - mapper = object_mapper(obj) + for (state, connection) in tups: + mapper = _state_mapper(state) if table not in mapper._pks_by_table: continue params = {} - if not hasattr(obj, '_instance_key'): + if not _state_has_identity(state): continue else: delete.setdefault(connection, []).append(params) for col in mapper._pks_by_table[table]: - params[col.key] = mapper._get_attr_by_column(obj, col) + params[col.key] = mapper._get_state_attr_by_column(state, col) if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col): - params[mapper.version_id_col.key] = mapper._get_attr_by_column(obj, mapper.version_id_col) + params[mapper.version_id_col.key] = mapper._get_state_attr_by_column(state, mapper.version_id_col) # testlib.pragma exempt:__hash__ - deleted_objects.add((id(obj), obj, connection)) + deleted_objects.add((state, connection)) for connection, del_objects in delete.iteritems(): mapper = table_to_mapper[table] def comparator(a, b): @@ -1221,10 +1227,10 @@ class Mapper(object): if c.supports_sane_multi_rowcount() and c.rowcount != len(del_objects): raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects))) - for id_, obj, connection in deleted_objects: - for mapper in object_mapper(obj).iterate_to_root(): + for state, connection in deleted_objects: + for mapper in _state_mapper(state).iterate_to_root(): if 'after_delete' in mapper.extension.methods: - mapper.extension.after_delete(mapper, connection, obj) + mapper.extension.after_delete(mapper, connection, state.obj()) def register_dependencies(self, uowcommit, *args, **kwargs): """Register ``DependencyProcessor`` instances with a @@ -1237,7 +1243,7 @@ class Mapper(object): for prop in self.__props.values(): prop.register_dependencies(uowcommit, *args, **kwargs) - def cascade_iterator(self, type, object, recursive=None, halt_on=None): + def cascade_iterator(self, type, state, recursive=None, halt_on=None): """Iterate each element and its mapper in an object graph, for all relations that meet the given cascade rule. @@ -1245,19 +1251,22 @@ class Mapper(object): The name of the cascade rule (i.e. save-update, delete, etc.) - object - The lead object instance. child items will be processed per + state + The lead InstanceState. child items will be processed per the relations defined for this object's mapper. 
recursive Used by the function for internal context during recursive calls, leave as None. + + the return value are object instances; this provides a strong + reference so that they don't fall out of scope immediately. """ if recursive is None: recursive=util.IdentitySet() for prop in self.__props.values(): - for (c, m) in prop.cascade_iterator(type, object, recursive, halt_on=halt_on): + for (c, m) in prop.cascade_iterator(type, state, recursive, halt_on=halt_on): yield (c, m) def get_select_mapper(self): @@ -1503,6 +1512,9 @@ Mapper.logger = logging.class_logger(Mapper) def has_identity(object): return hasattr(object, '_instance_key') +def _state_has_identity(state): + return '_instance_key' in state.dict + def has_mapper(object): """Return True if the given object has had a mapper association set up, either through loading, or via insertion in a session. @@ -1510,6 +1522,9 @@ def has_mapper(object): return hasattr(object, '_entity_name') +def _state_mapper(state): + return state.class_._class_state.mappers[state.dict.get('_entity_name', None)] + def object_mapper(object, entity_name=None, raiseerror=True): """Given an object, return the primary Mapper associated with the object instance. diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 8980a4498a..1e6d3ba7bf 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -54,13 +54,13 @@ class ColumnProperty(StrategizedProperty): def copy(self): return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns) - - def getattr(self, object, column): - return getattr(object, self.key) - - def setattr(self, object, value, column): - setattr(object, self.key, value) + + def getattr(self, state, column): + return getattr(state.class_, self.key).impl.get(state) + def setattr(self, state, value, column): + getattr(state.class_, self.key).impl.set(state, value, None) + def merge(self, session, source, dest, dont_load, _recursive): setattr(dest, self.key, getattr(source, self.key, None)) @@ -95,18 +95,21 @@ class CompositeProperty(ColumnProperty): def copy(self): return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns) - def getattr(self, object, column): - obj = getattr(object, self.key) + def getattr(self, state, column): + obj = getattr(state.class_, self.key).impl.get(state) return self.get_col_value(column, obj) - def setattr(self, object, value, column): - obj = getattr(object, self.key, None) + def setattr(self, state, value, column): + # TODO: test coverage for this method + obj = getattr(state.class_, self.key).impl.get(state) if obj is None: obj = self.composite_class(*[None for c in self.columns]) + getattr(state.class_, self.key).impl.set(state, obj, None) + for a, b in zip(self.columns, value.__composite_values__()): if a is column: setattr(obj, b, value) - + def get_col_value(self, column, value): for a, b in zip(self.columns, value.__composite_values__()): if a is column: @@ -319,13 +322,13 @@ class PropertyLoader(StrategizedProperty): def merge(self, session, source, dest, dont_load, _recursive): if not "merge" in self.cascade: return - childlist = attributes.get_history(source, self.key, passive=True) - if childlist is None: + instances = attributes.get_as_list(source._state, self.key, passive=True) + if not instances: return if self.uselist: # sets a blank collection according to the correct list class dest_list = attributes.init_collection(dest, self.key) - for current in list(childlist): + 
for current in instances: obj = session.merge(current, entity_name=self.mapper.entity_name, dont_load=dont_load, _recursive=_recursive) if obj is not None: if dont_load: @@ -333,7 +336,7 @@ class PropertyLoader(StrategizedProperty): else: dest_list.append_with_event(obj) else: - current = list(childlist)[0] + current = instances[0] if current is not None: obj = session.merge(current, entity_name=self.mapper.entity_name, dont_load=dont_load, _recursive=_recursive) if obj is not None: @@ -342,19 +345,21 @@ class PropertyLoader(StrategizedProperty): else: setattr(dest, self.key, obj) - def cascade_iterator(self, type, object, recursive, halt_on=None): + def cascade_iterator(self, type, state, recursive, halt_on=None): if not type in self.cascade: return passive = type != 'delete' or self.passive_deletes mapper = self.mapper.primary_mapper() - for c in attributes.get_as_list(object, self.key, passive=passive): - if c is not None and c not in recursive and (halt_on is None or not halt_on(c)): - if not isinstance(c, self.mapper.class_): - raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__))) - recursive.add(c) - yield (c, mapper) - for (c2, m) in mapper.cascade_iterator(type, c, recursive): - yield (c2, m) + instances = attributes.get_as_list(state, self.key, passive=passive) + if instances: + for c in instances: + if c is not None and c not in recursive and (halt_on is None or not halt_on(c)): + if not isinstance(c, self.mapper.class_): + raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__))) + recursive.add(c) + yield (c, mapper) + for (c2, m) in mapper.cascade_iterator(type, c._state, recursive): + yield (c2, m) def _get_target_class(self): """Return the target class of the relation, even if the diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 5f8602105d..435ed4c5d7 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -765,16 +765,16 @@ class Session(object): if attribute_names: self._validate_persistent(instance) - expire_instance(instance, attribute_names=attribute_names) + _expire_state(instance._state, attribute_names=attribute_names) else: # pre-fetch the full cascade since the expire is going to # remove associations cascaded = list(_cascade_iterator('refresh-expire', instance)) self._validate_persistent(instance) - expire_instance(instance, None) + _expire_state(instance._state, None) for (c, m) in cascaded: self._validate_persistent(c) - expire_instance(c, None) + _expire_state(c._state, None) def prune(self): """Removes unreferenced instances cached in the identity map. @@ -799,7 +799,7 @@ class Session(object): self._validate_persistent(instance) for c, m in [(instance, None)] + list(_cascade_iterator('expunge', instance)): if c in self: - self.uow._remove_deleted(c) + self.uow._remove_deleted(c._state) self._unattach(c) def save(self, instance, entity_name=None): @@ -812,7 +812,6 @@ class Session(object): The `entity_name` keyword argument will further qualify the specific ``Mapper`` used to handle this instance. """ - self._save_impl(instance, entity_name=entity_name) self._cascade_save_or_update(instance) @@ -1052,12 +1051,12 @@ class Session(object): result of True. 
""" - return instance in self.uow.new or (hasattr(instance, '_instance_key') and self.identity_map.get(instance._instance_key) is instance) + return instance._state in self.uow.new or (hasattr(instance, '_instance_key') and self.identity_map.get(instance._instance_key) is instance) def __iter__(self): """return an iterator of all instances which are pending or persistent within this Session.""" - return iter(list(self.uow.new) + self.uow.identity_map.values()) + return iter(list(self.uow.new.values()) + self.uow.identity_map.values()) def is_modified(self, instance, include_collections=True, passive=False): """return True if the given instance has modified attributes. @@ -1079,7 +1078,8 @@ class Session(object): for attr in attributes._managed_attributes(instance.__class__): if not include_collections and hasattr(attr.impl, 'get_collection'): continue - if attr.get_history(instance).is_modified(): + (added, unchanged, deleted) = attr.get_history(instance) + if added or deleted: return True return False @@ -1097,13 +1097,13 @@ class Session(object): is_modified() method. """) - deleted = property(lambda s:s.uow.deleted, + deleted = property(lambda s:util.IdentitySet(s.uow.deleted.values()), doc="A ``Set`` of all instances marked as 'deleted' within this ``Session``") - new = property(lambda s:s.uow.new, + new = property(lambda s:util.IdentitySet(s.uow.new.values()), doc="A ``Set`` of all instances marked as 'new' within this ``Session``.") -def expire_instance(instance, attribute_names): +def _expire_state(state, attribute_names): """standalone expire instance function. installs a callable with the given instance's _state @@ -1113,13 +1113,13 @@ def expire_instance(instance, attribute_names): If the list is None or blank, the entire instance is expired. 
""" - if instance._state.trigger is None: + if state.trigger is None: def load_attributes(instance, attribute_names): if object_session(instance).query(instance.__class__)._get(instance._instance_key, refresh_instance=instance, only_load_props=attribute_names) is None: raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance)) - instance._state.trigger = load_attributes + state.trigger = load_attributes - instance._state.expire_attributes(attribute_names) + state.expire_attributes(attribute_names) register_attribute = unitofwork.register_attribute @@ -1127,7 +1127,7 @@ _sessions = weakref.WeakValueDictionary() def _cascade_iterator(cascade, instance, **kwargs): mapper = _object_mapper(instance) - for (o, m) in mapper.cascade_iterator(cascade, instance, **kwargs): + for (o, m) in mapper.cascade_iterator(cascade, instance._state, **kwargs): yield o, m def object_session(instance): @@ -1143,4 +1143,4 @@ def object_session(instance): # Lazy initialization to avoid circular imports unitofwork.object_session = object_session from sqlalchemy.orm import mapper -mapper.expire_instance = expire_instance \ No newline at end of file +mapper._expire_state = _expire_state \ No newline at end of file diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index 8132c7e4a5..2d6328514c 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -61,7 +61,7 @@ class ClauseSynchronizer(object): source_column = binary.right else: if binary.left in foreign_keys: - source_column=binary.right + source_column = binary.right dest_column = binary.left elif binary.right in foreign_keys: source_column = binary.left @@ -94,15 +94,10 @@ class SyncRule(object): """An instruction indicating how to populate the objects on each side of a relationship. - In other words, if table1 column A is joined against table2 column + E.g. if table1 column A is joined against table2 column B, and we are a one-to-many from table1 to table2, a syncrule would say *take the A attribute from object1 and assign it to the B attribute on object2*. - - A rule contains the source mapper, the source column, destination - column, destination mapper in the case of a one/many relationship, - and the integer direction of this mapper relative to the - association in the case of a many to many relationship. 
""" def __init__(self, source_mapper, source_column, dest_column, dest_mapper=None, issecondary=None): @@ -123,26 +118,26 @@ class SyncRule(object): self._dest_primary_key = self.dest_mapper is not None and self.dest_column in self.dest_mapper._pks_by_table[self.dest_column.table] and not self.dest_mapper.allow_null_pks return self._dest_primary_key - def execute(self, source, dest, obj, child, clearkeys): + def execute(self, source, dest, parent, child, clearkeys): if source is None: if self.issecondary is False: - source = obj + source = parent elif self.issecondary is True: source = child if clearkeys or source is None: value = None clearkeys = True else: - value = self.source_mapper._get_attr_by_column(source, self.source_column) + value = self.source_mapper._get_state_attr_by_column(source, self.source_column) if isinstance(dest, dict): dest[self.dest_column.key] = value else: if clearkeys and self.dest_primary_key(): - raise exceptions.AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (str(self.dest_column), mapperutil.instance_str(dest))) + raise exceptions.AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (str(self.dest_column), mapperutil.state_str(dest))) if logging.is_debug_enabled(self.logger): - self.logger.debug("execute() instances: %s(%s)->%s(%s) ('%s')" % (mapperutil.instance_str(source), str(self.source_column), mapperutil.instance_str(dest), str(self.dest_column), value)) - self.dest_mapper._set_attr_by_column(dest, self.dest_column, value) + self.logger.debug("execute() instances: %s(%s)->%s(%s) ('%s')" % (mapperutil.state_str(source), str(self.source_column), mapperutil.state_str(dest), str(self.dest_column), value)) + self.dest_mapper._set_state_attr_by_column(dest, self.dest_column, value) SyncRule.logger = logging.class_logger(SyncRule) diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 6854ab7bc2..e4c65a2146 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -19,18 +19,18 @@ new, dirty, or deleted and provides the capability to flush all those changes at once. """ -import gc, StringIO, weakref +import StringIO, weakref from sqlalchemy import util, logging, topological, exceptions from sqlalchemy.orm import attributes, interfaces from sqlalchemy.orm import util as mapperutil -from sqlalchemy.orm.mapper import object_mapper +from sqlalchemy.orm.mapper import object_mapper, _state_mapper # Load lazily object_session = None class UOWEventHandler(interfaces.AttributeExtension): - """An event handler added to all class attributes which handles - session operations. + """An event handler added to all relation attributes which handles + session cascade operations. """ def __init__(self, key, class_, cascade=None): @@ -64,10 +64,18 @@ class UOWEventHandler(interfaces.AttributeExtension): sess.save_or_update(newvalue, entity_name=ename) def register_attribute(class_, key, *args, **kwargs): + """overrides attributes.register_attribute() to add UOW event handlers + to new InstrumentedAttributes. 
+ """ + cascade = kwargs.pop('cascade', None) - extension = util.to_list(kwargs.pop('extension', None) or []) - extension.insert(0, UOWEventHandler(key, class_, cascade=cascade)) - kwargs['extension'] = extension + useobject = kwargs.get('useobject', False) + if useobject: + # for object-holding attributes, instrument UOWEventHandler + # to process per-attribute cascades + extension = util.to_list(kwargs.pop('extension', None) or []) + extension.insert(0, UOWEventHandler(key, class_, cascade=cascade)) + kwargs['extension'] = extension return attributes.register_attribute(class_, key, *args, **kwargs) @@ -86,56 +94,49 @@ class UnitOfWork(object): else: self.identity_map = attributes.StrongInstanceDict() - self.new = util.IdentitySet() #OrderedSet() - self.deleted = util.IdentitySet() + self.new = {} # InstanceState->object, strong refs object + self.deleted = {} # same self.logger = logging.instance_logger(self, echoflag=session.echo_uow) - def _remove_deleted(self, obj): - if hasattr(obj, "_instance_key"): - del self.identity_map[obj._instance_key] - try: - self.deleted.remove(obj) - except KeyError: - pass - try: - self.new.remove(obj) - except KeyError: - pass + def _remove_deleted(self, state): + if '_instance_key' in state.dict: + del self.identity_map[state.dict['_instance_key']] + self.deleted.pop(state, None) + self.new.pop(state, None) - def _is_valid(self, obj): - if (hasattr(obj, '_instance_key') and obj._instance_key not in self.identity_map) or \ - (not hasattr(obj, '_instance_key') and obj not in self.new): - return False + def _is_valid(self, state): + if '_instance_key' in state.dict: + return state.dict['_instance_key'] in self.identity_map else: - return True + return state in self.new - def _register_clean(self, obj): + def _register_clean(self, state): """register the given object as 'clean' (i.e. persistent) within this unit of work, after a save operation has taken place.""" - - if obj in self.new: - self.new.remove(obj) - if not hasattr(obj, '_instance_key'): - mapper = object_mapper(obj) - obj._instance_key = mapper.identity_key_from_instance(obj) - if hasattr(obj, '_sa_insert_order'): - delattr(obj, '_sa_insert_order') - self.identity_map[obj._instance_key] = obj - obj._state.commit_all() + + if '_instance_key' not in state.dict: + mapper = _state_mapper(state) + state.dict['_instance_key'] = mapper._identity_key_from_state(state) + if hasattr(state, 'insert_order'): + delattr(state, 'insert_order') + self.identity_map[state.dict['_instance_key']] = state.obj() + state.commit_all() + # remove from new last, might be the last strong ref + self.new.pop(state, None) def register_new(self, obj): """register the given object as 'new' (i.e. 
unsaved) within this unit of work.""" if hasattr(obj, '_instance_key'): raise exceptions.InvalidRequestError("Object '%s' already has an identity - it can't be registered as new" % repr(obj)) - if obj not in self.new: - self.new.add(obj) - obj._sa_insert_order = len(self.new) + if obj._state not in self.new: + self.new[obj._state] = obj + obj._state.insert_order = len(self.new) def register_deleted(self, obj): """register the given persistent object as 'to be deleted' within this unit of work.""" - self.deleted.add(obj) + self.deleted[obj._state] = obj def locate_dirty(self): """return a set of all persistent instances within this unit of work which @@ -144,14 +145,13 @@ class UnitOfWork(object): # a little bit of inlining for speed return util.IdentitySet([x for x in self.identity_map.values() - if x not in self.deleted + if x._state not in self.deleted and ( x._state.modified or (x.__class__._class_state.has_mutable_scalars and x.state.is_modified()) ) ]) - def flush(self, session, objects=None): """create a dependency tree of all pending SQL operations within this unit of work and execute.""" @@ -165,10 +165,13 @@ class UnitOfWork(object): or (x.class_._class_state.has_mutable_scalars and x.is_modified()) ] - if len(dirty) == 0 and len(self.deleted) == 0 and len(self.new) == 0: + if not dirty and not self.deleted and not self.new: return - - dirty = util.IdentitySet([x.obj() for x in dirty]).difference(self.deleted) + + deleted = util.Set(self.deleted) + new = util.Set(self.new) + + dirty = util.Set(dirty).difference(deleted) flush_context = UOWTransaction(self, session) @@ -176,27 +179,27 @@ class UnitOfWork(object): session.extension.before_flush(session, flush_context, objects) # create the set of all objects we want to operate upon - if objects is not None: + if objects: # specific list passed in - objset = util.IdentitySet(objects) + objset = util.Set([o._state for o in objects]) else: # or just everything - objset = util.IdentitySet(self.identity_map.values()).union(self.new) + objset = util.Set(self.identity_map.all_states()).union(new) # store objects whose fate has been decided - processed = util.IdentitySet() + processed = util.Set() # put all saves/updates into the flush context. detect top-level orphans and throw them into deleted. - for obj in self.new.union(dirty).intersection(objset).difference(self.deleted): - if obj in processed: + for state in new.union(dirty).intersection(objset).difference(deleted): + if state in processed: continue - flush_context.register_object(obj, isdelete=object_mapper(obj)._is_orphan(obj)) - processed.add(obj) + flush_context.register_object(state, isdelete=_state_mapper(state)._is_orphan(state.obj())) + processed.add(state) # put all remaining deletes into the flush context. 
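A minimal sketch of the partition the rewritten flush() computes above, using plain Python sets in place of the InstanceState collections (names mirror flush()'s locals; illustration only):

    # pending or modified states, limited to the requested set and not
    # already marked for deletion, become save/update operations
    new, dirty, deleted = set([1, 2]), set([3]), set([4])
    objset = set([1, 2, 3, 4])   # everything under consideration
    processed = set()            # states whose fate is already decided
    save_states = new.union(dirty).intersection(objset).difference(deleted)
    # whatever remains of the deleted states within the requested set
    # becomes delete operations, handled by the loop below
    delete_states = deleted.intersection(objset).difference(processed)
    assert save_states == set([1, 2, 3]) and delete_states == set([4])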
- for obj in self.deleted.intersection(objset).difference(processed): - flush_context.register_object(obj, isdelete=True) + for state in deleted.intersection(objset).difference(processed): + flush_context.register_object(state, isdelete=True) if len(flush_context.tasks) == 0: return @@ -236,7 +239,6 @@ class UnitOfWork(object): dirty = self.locate_dirty() keepers = weakref.WeakValueDictionary(self.identity_map) self.identity_map.clear() - gc.collect() self.identity_map.update(keepers) return ref_count - len(self.identity_map) @@ -268,32 +270,23 @@ class UOWTransaction(object): self.logger = logging.instance_logger(self, echoflag=session.echo_uow) - def register_object(self, obj, isdelete = False, listonly = False, postupdate=False, post_update_cols=None, **kwargs): - """Add an object to this ``UOWTransaction`` to be updated in the database. - - This operation has the combined effect of locating/creating an appropriate - ``UOWTask`` object, and calling its ``append()`` method which then locates/creates - an appropriate ``UOWTaskElement`` object. - """ - - #print "REGISTER", repr(obj), repr(getattr(obj, '_instance_key', None)), str(isdelete), str(listonly) - + def register_object(self, state, isdelete = False, listonly = False, postupdate=False, post_update_cols=None, **kwargs): # if object is not in the overall session, do nothing - if not self.uow._is_valid(obj): + if not self.uow._is_valid(state): if self._should_log_debug: - self.logger.debug("object %s not part of session, not registering for flush" % (mapperutil.instance_str(obj))) + self.logger.debug("object %s not part of session, not registering for flush" % (mapperutil.state_str(state))) return if self._should_log_debug: - self.logger.debug("register object for flush: %s isdelete=%s listonly=%s postupdate=%s" % (mapperutil.instance_str(obj), isdelete, listonly, postupdate)) + self.logger.debug("register object for flush: %s isdelete=%s listonly=%s postupdate=%s" % (mapperutil.state_str(state), isdelete, listonly, postupdate)) - mapper = object_mapper(obj) + mapper = _state_mapper(state) + task = self.get_task_by_mapper(mapper) if postupdate: - task.append_postupdate(obj, post_update_cols) - return - - task.append(obj, listonly, isdelete=isdelete, **kwargs) + task.append_postupdate(state, post_update_cols) + else: + task.append(state, listonly, isdelete=isdelete, **kwargs) def set_row_switch(self, obj): """mark a deleted object as a 'row switch'. @@ -303,7 +296,7 @@ class UOWTransaction(object): """ mapper = object_mapper(obj) task = self.get_task_by_mapper(mapper) - taskelement = task._objects[id(obj)] + taskelement = task._objects[obj._state] taskelement.isdelete = "rowswitch" def unregister_object(self, obj): @@ -313,16 +306,21 @@ class UOWTransaction(object): no further operations occur upon the instance.""" mapper = object_mapper(obj) task = self.get_task_by_mapper(mapper) - if id(obj) in task._objects: - task.delete(obj) + if obj._state in task._objects: + task.delete(obj._state) def is_deleted(self, obj): """return true if the given object is marked as deleted within this UOWTransaction.""" mapper = object_mapper(obj) task = self.get_task_by_mapper(mapper) - return task.is_deleted(obj) + return task.is_deleted(obj._state) + def state_is_deleted(self, state): + mapper = _state_mapper(state) + task = self.get_task_by_mapper(mapper) + return task.is_deleted(state) + def get_task_by_mapper(self, mapper, dontcreate=False): """return UOWTask element corresponding to the given mapper. 
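The new/deleted collections manipulated above are now plain dicts so that each InstanceState keeps a strong reference to its object for as long as it is pending; a minimal sketch of that pattern (hypothetical class, not part of the patch):

    class StrongRefRegistry(object):
        """mirrors UnitOfWork.new / UnitOfWork.deleted: the state is the
        dict key, the object itself is the value, pinning it in memory."""
        def __init__(self):
            self._states = {}
        def add(self, state, obj):
            self._states[state] = obj
        def discard(self, state):
            self._states.pop(state, None)
        def public_view(self):
            # cf. Session.new / Session.deleted, which wrap the values
            # in an IdentitySet for the public API
            return list(self._states.values())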
@@ -339,13 +337,11 @@ class UOWTransaction(object): if base_mapper in self.tasks: base_task = self.tasks[base_mapper] else: - base_task = UOWTask(self, base_mapper) - self.tasks[base_mapper] = base_task + self.tasks[base_mapper] = base_task = UOWTask(self, base_mapper) base_mapper.register_dependencies(self) if mapper not in self.tasks: - task = UOWTask(self, mapper, base_task=base_task) - self.tasks[mapper] = task + self.tasks[mapper] = task = UOWTask(self, mapper, base_task=base_task) mapper.register_dependencies(self) else: task = self.tasks[mapper] @@ -360,7 +356,7 @@ class UOWTransaction(object): by another. """ - # correct for primary mapper (the mapper offcially associated with the class) + # correct for primary mapper # also convert to the "base mapper", the parentmost task at the top of an inheritance chain # dependency sorting is done via non-inheriting mappers only, dependencies between mappers # in the same inheritance chain is done at the per-object level @@ -370,29 +366,12 @@ class UOWTransaction(object): self.dependencies.add((mapper, dependency)) def register_processor(self, mapper, processor, mapperfrom): - """register a dependency processor object, corresponding to dependencies between + """register a dependency processor, corresponding to dependencies between the two given mappers. - In reality, the processor is an instance of ``dependency.DependencyProcessor`` - and is registered as a result of the ``mapper.register_dependencies()`` call in - ``get_task_by_mapper()``. - - The dependency processor supports the methods ``preprocess_dependencies()`` and - ``process_dependencies()``, which - perform operations on a list of instances that have a dependency relationship - with some other instance. The operations include adding items to the UOW - corresponding to some cascade operations, issuing inserts/deletes on - association tables, and synchronzing foreign key values between related objects - before the dependent object is operated upon at the SQL level. """ - # when the task from "mapper" executes, take the objects from the task corresponding - # to "mapperfrom"'s list of save/delete objects, and send them to "processor" - # for dependency processing - - #print "registerprocessor", str(mapper), repr(processor), repr(processor.key), str(mapperfrom) - - # correct for primary mapper (the mapper offcially associated with the class) + # correct for primary mapper mapper = mapper.primary_mapper() mapperfrom = mapperfrom.primary_mapper() @@ -404,11 +383,12 @@ class UOWTransaction(object): def execute(self): """Execute this UOWTransaction. - This will organize all collected UOWTasks into a toplogically-sorted - dependency tree, which is then traversed using the traversal scheme + This will organize all collected UOWTasks into a dependency-sorted + list which is then traversed using the traversal scheme encoded in the UOWExecutor class. Operations to mappers and dependency processors are fired off in order to issue SQL to the database and - to maintain instance state during the execution.""" + synchronize instance attributes with database values and related + foreign key values.""" # pre-execute dependency processors. 
this process may # result in new tasks, objects and/or dependency processors being added, @@ -424,17 +404,19 @@ class UOWTransaction(object): if not ret: break - head = self._sort_dependencies() + tasks = self._sort_dependencies() if self._should_log_info: - if head is None: - self.logger.info("Task dump: None") - else: - self.logger.info("Task dump:\n" + head.dump()) - if head is not None: - UOWExecutor().execute(self, head) + self.logger.info("Task dump:\n" + self._dump(tasks)) + UOWExecutor().execute(self, tasks) if self._should_log_info: self.logger.info("Execute Complete") + def _dump(self, tasks): + buf = StringIO.StringIO() + import uowdumper + uowdumper.UOWDumper(tasks, buf) + return buf.getvalue() + def post_exec(self): """mark processed objects as clean / deleted after a successful flush(). @@ -444,49 +426,31 @@ class UOWTransaction(object): for task in self.tasks.values(): for elem in task.elements: - if elem.obj is None: + if elem.state is None: continue if elem.isdelete: - self.uow._remove_deleted(elem.obj) + self.uow._remove_deleted(elem.state) else: - self.uow._register_clean(elem.obj) + self.uow._register_clean(elem.state) def _sort_dependencies(self): - """Create a hierarchical tree of dependent UOWTask instances. - - The root UOWTask is returned. - - Cyclical relationships - within the toplogical sort are further broken down into new - temporary UOWTask insances which represent smaller sub-groups of objects - that would normally belong to a single UOWTask. - - """ - - def sort_hier(node): - if node is None: - return None - task = self.get_task_by_mapper(node.item) - if node.cycles is not None: - tasks = [] - for n in node.cycles: - tasks.append(self.get_task_by_mapper(n.item)) - task = task._sort_circular_dependencies(self, tasks) - for child in node.children: - t = sort_hier(child) - if t is not None: - task.childtasks.append(t) - return task + nodes = topological.sort_with_cycles(self.dependencies, + [t.mapper for t in self.tasks.values() if t.base_task is t] + ) + + ret = [] + for item, cycles in nodes: + task = self.get_task_by_mapper(item) + if cycles: + for t in task._sort_circular_dependencies(self, [self.get_task_by_mapper(i) for i in cycles]): + ret.append(t) + else: + ret.append(task) - # get list of base mappers - mappers = [t.mapper for t in self.tasks.values() if t.base_task is t] - head = topological.QueueDependencySorter(self.dependencies, mappers).sort(allow_cycles=True) if self._should_log_debug: self.logger.debug("Dependent tuples:\n" + "\n".join(["(%s->%s)" % (d[0].class_.__name__, d[1].class_.__name__) for d in self.dependencies])) - self.logger.debug("Dependency sort:\n"+ str(head)) - task = sort_hier(head) - return task - + self.logger.debug("Dependency sort:\n"+ str(ret)) + return ret class UOWTask(object): """Represents all of the objects in the UOWTransaction which correspond to @@ -495,7 +459,6 @@ class UOWTask(object): """ def __init__(self, uowtransaction, mapper, base_task=None): - # the transaction owning this UOWTask self.uowtransaction = uowtransaction # base_task is the UOWTask which represents the "base mapper" @@ -514,31 +477,11 @@ class UOWTask(object): # the Mapper which this UOWTask corresponds to self.mapper = mapper - # a dictionary mapping object instances to a corresponding UOWTaskElement. - # Each UOWTaskElement represents one object instance which is to be saved or - # deleted by this UOWTask's Mapper. 
- # in the case of the row-based "cyclical sort", the UOWTaskElement may - # also reference further UOWTasks which are dependent on that UOWTaskElement. + # mapping of InstanceState -> UOWTaskElement self._objects = {} - # a set of UOWDependencyProcessor instances, which are executed after saves and - # before deletes, to synchronize data between dependent objects as well as to - # ensure that relationship cascades populate the flush() process with all - # appropriate objects. self._dependencies = util.Set() - - # a list of UOWTasks which are sub-nodes to this UOWTask. this list - # is populated during the dependency sorting operation. - self.childtasks = [] - - # a list of UOWDependencyProcessor instances - # which derive from the UOWDependencyProcessor instances present in a - # corresponding UOWTask's "_dependencies" set. This collection is populated - # during a row-based cyclical sorting operation and only corresponds to - # new UOWTask instances created during this operation, which are also local - # to the dependency graph (i.e. they are not present in the get_task_by_mapper() - # collection). - self._cyclical_dependencies = util.Set() + self.cyclical_dependencies = util.Set() def polymorphic_tasks(self): """return an iterator of UOWTask objects corresponding to the inheritance sequence @@ -569,65 +512,29 @@ class UOWTask(object): t = self.base_task._inheriting_tasks.get(mapper, None) if t is not None: yield t - + def is_empty(self): """return True if this UOWTask is 'empty', meaning it has no child items. - + used only for debugging output. """ - - return len(self._objects) == 0 and len(self._dependencies) == 0 and len(self.childtasks) == 0 - def append(self, obj, listonly = False, childtask = None, isdelete = False): - """Append an object to this task to be persisted or deleted. - - The actual record added to the ``UOWTask`` is a ``UOWTaskElement`` object - corresponding to the given instance. If a corresponding ``UOWTaskElement`` already - exists within this ``UOWTask``, its state is updated with the given - keyword arguments as appropriate. - - 'isdelete' when True indicates the operation will be a "delete" - operation (i.e. DELETE), otherwise is a "save" operation (i.e. INSERT/UPDATE). - a ``UOWTaskElement`` marked as "save" which receives the "isdelete" flag will - be marked as deleted, but the reverse operation does not apply (i.e. goes from - "delete" to being "not delete"). - - `listonly` indicates that the object does not require a delete - or save operation, but does require dependency operations to be - executed. For example, adding a child object to a parent via a - one-to-many relationship requires that a ``OneToManyDP`` object - corresponding to the parent's mapper synchronize the instance's primary key - value into the foreign key attribute of the child object, even though - no changes need be persisted on the parent. - - a listonly object may be "upgraded" to require a save/delete operation - by a subsequent append() of the same object instance with the `listonly` - flag set to False. once the flag is set to false, it stays that way - on the ``UOWTaskElement``. + return not self._objects and not self._dependencies + + def append(self, state, listonly=False, isdelete=False): + if state not in self._objects: + self._objects[state] = rec = UOWTaskElement(state) + else: + rec = self._objects[state] - `childtask` is an optional ``UOWTask`` element represending operations which - are dependent on the parent ``UOWTaskElement``. 
This flag is only used on
-        `UOWTask` objects created within the "cyclical sort" part of the hierarchical
-        sort, which generates a dependency tree of individual instances instead of
-        mappers when cycles between mappers are detected.
-        """
+        rec.update(listonly, isdelete)
+
+    def _append_cyclical_childtask(self, task):
+        if "cyclical" not in self._objects:
+            self._objects["cyclical"] = UOWTaskElement(None)
+        self._objects["cyclical"].childtasks.append(task)
 
-        try:
-            rec = self._objects[id(obj)]
-            retval = False
-        except KeyError:
-            rec = UOWTaskElement(obj)
-            self._objects[id(obj)] = rec
-            retval = True
-        if not listonly:
-            rec.listonly = False
-        if childtask:
-            rec.childtasks.append(childtask)
-        if isdelete:
-            rec.isdelete = True
-        return retval
-
-    def append_postupdate(self, obj, post_update_cols):
+    def append_postupdate(self, state, post_update_cols):
        """issue a 'post update' UPDATE statement via this object's mapper immediately.
 
        this operation is used only with relations that specify the `post_update=True`
@@ -637,31 +544,27 @@
        # postupdates are UPDATED immeditely (for now)
        # convert post_update_cols list to a Set so that __hashcode__ is used to compare columns
        # instead of __eq__
-        self.mapper.save_obj([obj], self.uowtransaction, postupdate=True, post_update_cols=util.Set(post_update_cols))
-        return True
+        self.mapper.save_obj([state], self.uowtransaction, postupdate=True, post_update_cols=util.Set(post_update_cols))
 
    def delete(self, obj):
        """remove the given object from this UOWTask, if present."""
-
-        try:
-            del self._objects[id(obj)]
-        except KeyError:
-            pass
 
-    def __contains__(self, obj):
+        self._objects.pop(obj, None)
+
+    def __contains__(self, state):
        """return True if the given object is contained within this UOWTask
        or inheriting tasks."""
 
        for task in self.polymorphic_tasks():
-            if id(obj) in task._objects:
+            if state in task._objects:
                return True
        else:
            return False
 
-    def is_deleted(self, obj):
+    def is_deleted(self, state):
        """return True if the given object is marked as to be deleted within this UOWTask."""
 
        try:
-            return self._objects[id(obj)].isdelete
+            return self._objects[state].isdelete
        except KeyError:
            return False
 
@@ -685,20 +588,16 @@
    polymorphic_todelete_elements = property(lambda self:[rec for rec in self.polymorphic_elements
                                              if rec.isdelete])
 
-    polymorphic_tosave_objects = property(lambda self:[rec.obj for rec in self.polymorphic_elements
-                                          if rec.obj is not None and not rec.listonly and rec.isdelete is False])
+    polymorphic_tosave_objects = property(lambda self:[rec.state for rec in self.polymorphic_elements
+                                          if rec.state is not None and not rec.listonly and rec.isdelete is False])
 
-    polymorphic_todelete_objects = property(lambda self:[rec.obj for rec in self.polymorphic_elements
-                                            if rec.obj is not None and not rec.listonly and rec.isdelete is True])
+    polymorphic_todelete_objects = property(lambda self:[rec.state for rec in self.polymorphic_elements
+                                            if rec.state is not None and not rec.listonly and rec.isdelete is True])
 
    dependencies = property(lambda self:self._dependencies)
 
-    cyclical_dependencies = property(lambda self:self._cyclical_dependencies)
-
    polymorphic_dependencies = _polymorphic_collection(lambda task:task.dependencies)
 
-    polymorphic_childtasks = _polymorphic_collection(lambda task:task.childtasks)
-
    polymorphic_cyclical_dependencies = _polymorphic_collection(lambda task:task.cyclical_dependencies)
 
    def _sort_circular_dependencies(self, trans, cycles):
@@ -713,29 +612,19 @@ class 
UOWTask(object): """ allobjects = [] for task in cycles: - allobjects += [e.obj for e in task.polymorphic_elements] + allobjects += [e.state for e in task.polymorphic_elements] tuples = [] cycles = util.Set(cycles) - #print "BEGIN CIRC SORT-------" - #print "PRE-CIRC:" - #print list(cycles) #[0].dump() - - # dependency processors that arent part of the cyclical thing - # get put here extradeplist = [] - - # organizes a set of new UOWTasks that will be assembled into - # the final tree, for the purposes of holding new UOWDependencyProcessors - # which process small sub-sections of dependent parent/child operations dependencies = {} - def get_dependency_task(obj, depprocessor): + def get_dependency_task(state, depprocessor): try: - dp = dependencies[id(obj)] + dp = dependencies[state] except KeyError: - dp = dependencies.setdefault(id(obj), {}) + dp = dependencies.setdefault(state, {}) try: l = dp[depprocessor] except KeyError: @@ -763,8 +652,8 @@ class UOWTask(object): for task in cycles: for subtask in task.polymorphic_tasks(): for taskelement in subtask.elements: - obj = taskelement.obj - object_to_original_task[id(obj)] = subtask + state = taskelement.state + object_to_original_task[state] = subtask for dep in deps_by_targettask.get(subtask, []): # is this dependency involved in one of the cycles ? if not dependency_in_cycles(dep): @@ -773,13 +662,14 @@ class UOWTask(object): isdelete = taskelement.isdelete # list of dependent objects from this object - childlist = dep.get_object_dependencies(obj, trans, passive=True) - if childlist is None: + (added, unchanged, deleted) = dep.get_object_dependencies(state, trans, passive=True) + if not added and not unchanged and not deleted: continue + # the task corresponding to saving/deleting of those dependent objects childtask = trans.get_task_by_mapper(processor.mapper) - childlist = childlist.added_items() + childlist.unchanged_items() + childlist.deleted_items() + childlist = added + unchanged + deleted for o in childlist: # other object is None. this can occur if the relationship is many-to-one @@ -793,46 +683,42 @@ class UOWTask(object): # task if o not in childtask: childtask.append(o, listonly=True) - object_to_original_task[id(o)] = childtask + object_to_original_task[o] = childtask # create a tuple representing the "parent/child" - whosdep = dep.whose_dependent_on_who(obj, o) + whosdep = dep.whose_dependent_on_who(state, o) if whosdep is not None: # append the tuple to the partial ordering. tuples.append(whosdep) # create a UOWDependencyProcessor representing this pair of objects. 
# append it to a UOWTask
-                        if whosdep[0] is obj:
+                        if whosdep[0] is state:
                            get_dependency_task(whosdep[0], dep).append(whosdep[0], isdelete=isdelete)
                        else:
                            get_dependency_task(whosdep[0], dep).append(whosdep[1], isdelete=isdelete)
                    else:
-                        get_dependency_task(obj, dep).append(obj, isdelete=isdelete)
+                        get_dependency_task(state, dep).append(state, isdelete=isdelete)
 
-        #print "TUPLES", tuples
-        #print "ALLOBJECTS", allobjects
-        head = topological.QueueDependencySorter(tuples, allobjects).sort()
-
-        # create a tree of UOWTasks corresponding to the tree of object instances
-        # created by the DependencySorter
+        head = topological.sort_as_tree(tuples, allobjects)
 
        used_tasks = util.Set()
        def make_task_tree(node, parenttask, nexttasks):
-            originating_task = object_to_original_task[id(node.item)]
+            (state, cycles, children) = node
+            originating_task = object_to_original_task[state]
            used_tasks.add(originating_task)
            t = nexttasks.get(originating_task, None)
            if t is None:
                t = UOWTask(self.uowtransaction, originating_task.mapper)
                nexttasks[originating_task] = t
-                parenttask.append(None, listonly=False, isdelete=originating_task._objects[id(node.item)].isdelete, childtask=t)
-            t.append(node.item, originating_task._objects[id(node.item)].listonly, isdelete=originating_task._objects[id(node.item)].isdelete)
+                parenttask._append_cyclical_childtask(t)
+            t.append(state, originating_task._objects[state].listonly, isdelete=originating_task._objects[state].isdelete)
 
-            if id(node.item) in dependencies:
-                for depprocessor, deptask in dependencies[id(node.item)].iteritems():
+            if state in dependencies:
+                for depprocessor, deptask in dependencies[state].iteritems():
                    t.cyclical_dependencies.add(depprocessor.branch(deptask))
            nd = {}
-            for n in node.children:
+            for n in children:
                t2 = make_task_tree(n, t, nd)
            return t
 
@@ -842,43 +728,28 @@
        for d in extradeplist:
            t._dependencies.add(d)
 
-        # if we have a head from the dependency sort, assemble child nodes
-        # onto the tree.  note this only occurs if there were actual objects
-        # to be saved/deleted.
        if head is not None:
            make_task_tree(head, t, {})
 
+        ret = [t]
        for t2 in cycles:
-            # tasks that were in the cycle but did not get assembled
-            # into the tree, add them as child tasks.  these tasks
-            # will have no "save" or "delete" members, but may have dependency
-            # processors that operate upon other tasks outside of the cycle.
            if t2 not in used_tasks and t2 is not self:
-                # the task must be copied into a "cyclical" task, so that polymorphic
-                # rules dont fire off. this ensures that the task will have no "save"
-                # or "delete" members due to inheriting mappers which contain tasks
+                # add tasks that were in the cycle, but didn't get assembled
+                # into the cyclical tree, to the start of the list
+                # TODO: no test coverage for this !! 
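For reference, make_task_tree() above consumes the 3-tuple nodes produced by topological.sort_as_tree(); a minimal sketch of walking that structure (hypothetical helper, illustration only):

    def walk(node, depth=0):
        # each node is (item, cycles, children); 'cycles' lists items in
        # a cycle with this item, 'children' holds nested 3-tuples
        (item, cycles, children) = node
        print "  " * depth + repr(item) + (cycles and " (cycle)" or "")
        for child in children:
            walk(child, depth + 1)

    walk(("root", [], [("child", [], [])]))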
localtask = UOWTask(self.uowtransaction, t2.mapper)
-                for obj in t2.elements:
-                    localtask.append(obj, t2.listonly, isdelete=t2._objects[id(obj)].isdelete)
+                for elem in t2.elements:
+                    localtask.append(elem.state, elem.listonly, isdelete=elem.isdelete)
                for dep in t2.dependencies:
                    localtask._dependencies.add(dep)
-                t.childtasks.insert(0, localtask)
-
-        return t
-
-    def dump(self):
-        """return a string representation of this UOWTask and its
-        full dependency graph."""
+                ret.insert(0, localtask)
 
-        buf = StringIO.StringIO()
-        import uowdumper
-        uowdumper.UOWDumper(self, buf)
-        return buf.getvalue()
+        return ret
 
    def __repr__(self):
        if self.mapper is not None:
            if self.mapper.__class__.__name__ == 'Mapper':
-                name = self.mapper.class_.__name__ + "/" + self.mapper.local_table.name
+                name = self.mapper.class_.__name__ + "/" + self.mapper.local_table.description
            else:
                name = repr(self.mapper)
        else:
@@ -892,58 +763,36 @@ class UOWTaskElement(object):
        just part of the transaction as a placeholder for further
        dependencies (i.e. 'listonly').
 
-        In the case of a ``UOWTaskElement`` present within an instance-level
-        graph formed due to cycles within the mapper-level graph, may also store a list of
-        childtasks, further UOWTasks containing objects dependent on this
-        element's object instance.
+        may also store additional sub-UOWTasks.
    """
 
-    def __init__(self, obj):
-        self.obj = obj
-        self.__listonly = True
+    def __init__(self, state):
+        self.state = state
+        self.listonly = True
        self.childtasks = []
-        self.__isdelete = False
+        self.isdelete = False
        self.__preprocessed = {}
-
-    def _get_listonly(self):
-        return self.__listonly
 
-    def _set_listonly(self, value):
-        """Set_listonly is a one-way setter, will only go from True to False."""
-
-        if not value and self.__listonly:
-            self.__listonly = False
-            self.clear_preprocessed()
-
-    def _get_isdelete(self):
-        return self.__isdelete
-
-    def _set_isdelete(self, value):
-        if self.__isdelete is not value:
-            self.__isdelete = value
-            self.clear_preprocessed()
-
-    listonly = property(_get_listonly, _set_listonly)
-    isdelete = property(_get_isdelete, _set_isdelete)
+    def update(self, listonly, isdelete):
+        if not listonly and self.listonly:
+            self.listonly = False
+            self.__preprocessed.clear()
+        if isdelete and not self.isdelete:
+            self.isdelete = True
+            self.__preprocessed.clear()
 
    def mark_preprocessed(self, processor):
        """Mark this element as *preprocessed* by a particular ``UOWDependencyProcessor``.
 
-        Preprocessing is the step which sweeps through all the
-        relationships on all the objects in the flush transaction and
-        adds other objects which are also affected.  The actual logic is
-        part of ``UOWTransaction.execute()``.
-
-        The preprocessing operations
-        are determined in part by the cascade rules indicated on a relationship,
-        and in part based on the normal semantics of relationships.
-        In some cases it can switch an object's state from *tosave* to *todelete*.
-
-        Changes to the state of this ``UOWTaskElement`` will reset all
-        *preprocessed* flags, causing it to be preprocessed again.
-        When all ``UOWTaskElements have been fully preprocessed by all
-        UOWDependencyProcessors, then the topological sort can be
-        done.
+        Preprocessing is used by dependency.py to apply
+        flush-time cascade rules to relations and bring all
+        required objects into the flush context. 
+
+        each processor is marked as "processed" when complete; however,
+        changes to the state of this UOWTaskElement will reset
+        the list of completed processors, so that they
+        execute again, until no new objects or state changes
+        are brought in.
        """
 
        self.__preprocessed[processor] = True
@@ -951,9 +800,6 @@
    def is_preprocessed(self, processor):
        return self.__preprocessed.get(processor, False)
 
-    def clear_preprocessed(self):
-        self.__preprocessed.clear()
-
    def __repr__(self):
        return "UOWTaskElement/%d: %s/%d %s" % (id(self), self.obj.__class__.__name__, id(self.obj), (self.listonly and 'listonly' or (self.isdelete and 'delete' or 'save')) )
@@ -998,15 +844,15 @@ class UOWDependencyProcessor(object):
        def getobj(elem):
            elem.mark_preprocessed(self)
-            return elem.obj
+            return elem.state
 
        ret = False
-        elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if elem.obj is not None and not elem.is_preprocessed(self)]
+        elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if elem.state is not None and not elem.is_preprocessed(self)]
        if elements:
            ret = True
            self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=False)
 
-        elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if elem.obj is not None and not elem.is_preprocessed(self)]
+        elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if elem.state is not None and not elem.is_preprocessed(self)]
        if elements:
            ret = True
            self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True)
@@ -1016,14 +862,14 @@
        """process all objects contained within this ``UOWDependencyProcessor``s target task."""
 
        if not delete:
-            self.processor.process_dependencies(self.targettask, [elem.obj for elem in self.targettask.polymorphic_tosave_elements if elem.obj is not None], trans, delete=False)
+            self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_tosave_elements if elem.state is not None], trans, delete=False)
        else:
-            self.processor.process_dependencies(self.targettask, [elem.obj for elem in self.targettask.polymorphic_todelete_elements if elem.obj is not None], trans, delete=True)
+            self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_todelete_elements if elem.state is not None], trans, delete=True)
 
-    def get_object_dependencies(self, obj, trans, passive):
-        return self.processor.get_object_dependencies(obj, trans, passive=passive)
+    def get_object_dependencies(self, state, trans, passive):
+        return self.processor.get_object_dependencies(state, trans, passive=passive)
 
-    def whose_dependent_on_who(self, obj, o):
+    def whose_dependent_on_who(self, state1, state2):
        """establish which object is operationally dependent amongst a parent/child
        using the semantics stated by the dependency processor.
 
@@ -1032,7 +878,7 @@
        """
 
-        return self.processor.whose_dependent_on_who(obj, o)
+        return self.processor.whose_dependent_on_who(state1, state2)
 
    def branch(self, task):
        """create a copy of this ``UOWDependencyProcessor`` against a new ``UOWTask`` object. 
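preexecute() above participates in a fixpoint loop: UOWTransaction.execute() re-runs every dependency processor until a complete pass brings in no new objects or state changes. A minimal sketch of that contract, with illustrative names:

    def preexecute_all(trans):
        # sweep all dependency processors repeatedly; each one returns
        # True if it pulled new elements into the flush, forcing
        # another full pass
        while True:
            ret = False
            for dep in list(trans.dependencies):
                if dep.preexecute(trans):
                    ret = True
            if not ret:
                break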
@@ -1048,11 +894,13 @@ class UOWDependencyProcessor(object): class UOWExecutor(object): """Encapsulates the execution traversal of a UOWTransaction structure.""" - def execute(self, trans, task, isdelete=None): + def execute(self, trans, tasks, isdelete=None): if isdelete is not True: - self.execute_save_steps(trans, task) + for task in tasks: + self.execute_save_steps(trans, task) if isdelete is not False: - self.execute_delete_steps(trans, task) + for task in util.reversed(tasks): + self.execute_delete_steps(trans, task) def save_objects(self, trans, task): task.mapper.save_obj(task.polymorphic_tosave_objects, trans) @@ -1069,11 +917,9 @@ class UOWExecutor(object): self.execute_per_element_childtasks(trans, task, False) self.execute_dependencies(trans, task, False) self.execute_dependencies(trans, task, True) - self.execute_childtasks(trans, task, False) - + def execute_delete_steps(self, trans, task): self.execute_cyclical_dependencies(trans, task, True) - self.execute_childtasks(trans, task, True) self.execute_per_element_childtasks(trans, task, True) self.delete_objects(trans, task) @@ -1085,10 +931,6 @@ class UOWExecutor(object): for dep in util.reversed(list(task.polymorphic_dependencies)): self.execute_dependency(trans, dep, True) - def execute_childtasks(self, trans, task, isdelete=None): - for child in task.polymorphic_childtasks: - self.execute(trans, child, isdelete) - def execute_cyclical_dependencies(self, trans, task, isdelete): for dep in task.polymorphic_cyclical_dependencies: self.execute_dependency(trans, dep, isdelete) @@ -1099,5 +941,5 @@ class UOWExecutor(object): def execute_element_childtasks(self, trans, element, isdelete): for child in element.childtasks: - self.execute(trans, child, isdelete) + self.execute(trans, [child], isdelete) diff --git a/lib/sqlalchemy/orm/uowdumper.py b/lib/sqlalchemy/orm/uowdumper.py index 83bd63f346..ba6d09261d 100644 --- a/lib/sqlalchemy/orm/uowdumper.py +++ b/lib/sqlalchemy/orm/uowdumper.py @@ -8,58 +8,53 @@ from sqlalchemy.orm import unitofwork from sqlalchemy.orm import util as mapperutil +from sqlalchemy import util class UOWDumper(unitofwork.UOWExecutor): - def __init__(self, task, buf, verbose=False): + def __init__(self, tasks, buf, verbose=False): self.verbose = verbose self.indent = 0 - self.task = task + self.tasks = tasks self.buf = buf - self.starttask = task self.headers = {} - self.execute(None, task) + self.execute(None, tasks) - def execute(self, trans, task, isdelete=None): - oldstarttask = self.starttask - oldheaders = self.headers - self.starttask = task - self.headers = {} + def execute(self, trans, tasks, isdelete=None): + if isdelete is not True: + for task in tasks: + self._execute(trans, task, False) + if isdelete is not False: + for task in util.reversed(tasks): + self._execute(trans, task, True) + + def _execute(self, trans, task, isdelete): try: i = self._indent() if i: - i += "-" - #i = i[0:-1] + "-" - self.buf.write(self._indent() + "\n") + i = i[:-1] + "+-" self.buf.write(i + " " + self._repr_task(task)) self.buf.write(" (" + (isdelete and "delete " or "save/update ") + "phase) \n") self.indent += 1 - super(UOWDumper, self).execute(trans, task, isdelete) + super(UOWDumper, self).execute(trans, [task], isdelete) finally: self.indent -= 1 - if self.starttask.is_empty(): - self.buf.write(self._indent() + " |- (empty task)\n") - else: - self.buf.write(self._indent() + " |----\n") - self.buf.write(self._indent() + "\n") - self.starttask = oldstarttask - self.headers = oldheaders def save_objects(self, trans, 
task): # sort elements to be inserted by insert order def comparator(a, b): - if a.obj is None: + if a.state is None: x = None - elif not hasattr(a.obj, '_sa_insert_order'): + elif not hasattr(a.state, 'insert_order'): x = None else: - x = a.obj._sa_insert_order - if b.obj is None: + x = a.state.insert_order + if b.state is None: y = None - elif not hasattr(b.obj, '_sa_insert_order'): + elif not hasattr(b.state, 'insert_order'): y = None else: - y = b.obj._sa_insert_order + y = b.state.insert_order return cmp(x, y) l = list(task.polymorphic_tosave_elements) @@ -68,7 +63,7 @@ class UOWDumper(unitofwork.UOWExecutor): if rec.listonly: continue self.header("Save elements"+ self._inheritance_tag(task)) - self.buf.write(self._indent() + "- " + self._repr_task_element(rec) + "\n") + self.buf.write(self._indent()[:-1] + "+-" + self._repr_task_element(rec) + "\n") self.closeheader() def delete_objects(self, trans, task): @@ -82,10 +77,8 @@ class UOWDumper(unitofwork.UOWExecutor): def _inheritance_tag(self, task): if not self.verbose: return "" - elif task is not self.starttask: - return (" (inheriting task %s)" % self._repr_task(task)) else: - return "" + return (" (inheriting task %s)" % self._repr_task(task)) def header(self, text): """Write a given header just once.""" @@ -115,11 +108,6 @@ class UOWDumper(unitofwork.UOWExecutor): def execute_dependencies(self, trans, task, isdelete=None): super(UOWDumper, self).execute_dependencies(trans, task, isdelete) - def execute_childtasks(self, trans, task, isdelete=None): - self.header("Child tasks" + self._inheritance_tag(task)) - super(UOWDumper, self).execute_childtasks(trans, task, isdelete) - self.closeheader() - def execute_cyclical_dependencies(self, trans, task, isdelete): self.header("Cyclical %s dependencies" % (isdelete and "delete" or "save")) super(UOWDumper, self).execute_cyclical_dependencies(trans, task, isdelete) @@ -140,14 +128,14 @@ class UOWDumper(unitofwork.UOWExecutor): val = proc.targettask.polymorphic_tosave_elements if self.verbose: - self.buf.write(self._indent() + " |- %s attribute on %s (UOWDependencyProcessor(%d) processing %s)\n" % ( + self.buf.write(self._indent() + " +- %s attribute on %s (UOWDependencyProcessor(%d) processing %s)\n" % ( repr(proc.processor.key), ("%s's to be %s" % (self._repr_task_class(proc.targettask), deletes and "deleted" or "saved")), hex(id(proc)), self._repr_task(proc.targettask)) ) elif False: - self.buf.write(self._indent() + " |- %s attribute on %s\n" % ( + self.buf.write(self._indent() + " +- %s attribute on %s\n" % ( repr(proc.processor.key), ("%s's to be %s" % (self._repr_task_class(proc.targettask), deletes and "deleted" or "saved")), ) @@ -155,18 +143,18 @@ class UOWDumper(unitofwork.UOWExecutor): if len(val) == 0: if self.verbose: - self.buf.write(self._indent() + " |- " + "(no objects)\n") + self.buf.write(self._indent() + " +- " + "(no objects)\n") for v in val: - self.buf.write(self._indent() + " |- " + self._repr_task_element(v, proc.processor.key, process=True) + "\n") + self.buf.write(self._indent() + " +- " + self._repr_task_element(v, proc.processor.key, process=True) + "\n") def _repr_task_element(self, te, attribute=None, process=False): - if te.obj is None: + if getattr(te, 'state', None) is None: objid = "(placeholder)" else: if attribute is not None: - objid = "%s.%s" % (mapperutil.instance_str(te.obj), attribute) + objid = "%s.%s" % (mapperutil.state_str(te.state), attribute) else: - objid = mapperutil.instance_str(te.obj) + objid = mapperutil.state_str(te.state) if 
self.verbose: return "%s (UOWTaskElement(%s, %s))" % (objid, hex(id(te)), (te.listonly and 'listonly' or (te.isdelete and 'delete' or 'save'))) elif process: @@ -182,7 +170,11 @@ class UOWDumper(unitofwork.UOWExecutor): name = repr(task.mapper) else: name = '(none)' - return ("UOWTask(%s, %s)" % (hex(id(task)), name)) + sd = getattr(task, '_superduper', False) + if sd: + return ("SD UOWTask(%s, %s)" % (hex(id(task)), name)) + else: + return ("UOWTask(%s, %s)" % (hex(id(task)), name)) def _repr_task_class(self, task): if task.mapper is not None and task.mapper.__class__.__name__ == 'Mapper': diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index f2b92000b2..d2782ec0ab 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -279,6 +279,11 @@ def instance_str(instance): return instance.__class__.__name__ + "@" + hex(id(instance)) +def state_str(state): + """Return a string describing an instance.""" + + return state.class_.__name__ + "@" + hex(id(state.obj())) + def attribute_str(instance, attribute): return instance_str(instance) + "." + attribute diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 94950b872c..3af8f97cab 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -459,10 +459,7 @@ class DefaultCompiler(engine.Compiled): stack_entry = {'select':select} - if asfrom: - stack_entry['is_subquery'] = True - column_clause_args = {} - elif self.stack and 'select' in self.stack[-1]: + if asfrom or (self.stack and 'select' in self.stack[-1]): stack_entry['is_subquery'] = True column_clause_args = {} else: @@ -546,6 +543,7 @@ class DefaultCompiler(engine.Compiled): def get_select_precolumns(self, select): """Called when building a ``SELECT`` statement, position is just before column list.""" + return select._distinct and "DISTINCT " or "" def order_by_clause(self, select): @@ -624,8 +622,8 @@ class DefaultCompiler(engine.Compiled): self.binds[col.key] = bindparam return self.bindparam_string(self._truncate_bindparam(bindparam)) - self.postfetch = util.Set() - self.prefetch = util.Set() + self.postfetch = [] + self.prefetch = [] # no parameters in the statement, no parameters in the # compiled params - return binds for all columns @@ -651,7 +649,7 @@ class DefaultCompiler(engine.Compiled): if sql._is_literal(value): value = create_bind_param(c, value) else: - self.postfetch.add(c) + self.postfetch.append(c) value = self.process(value.self_group()) values.append((c, value)) elif isinstance(c, schema.Column): @@ -663,35 +661,35 @@ class DefaultCompiler(engine.Compiled): (c.default is not None and not isinstance(c.default, schema.Sequence))): values.append((c, create_bind_param(c, None))) - self.prefetch.add(c) + self.prefetch.append(c) elif isinstance(c.default, schema.ColumnDefault): if isinstance(c.default.arg, sql.ClauseElement): values.append((c, self.process(c.default.arg.self_group()))) if not c.primary_key: # dont add primary key column to postfetch - self.postfetch.add(c) + self.postfetch.append(c) else: values.append((c, create_bind_param(c, None))) - self.prefetch.add(c) + self.prefetch.append(c) elif isinstance(c.default, schema.PassiveDefault): if not c.primary_key: - self.postfetch.add(c) + self.postfetch.append(c) elif isinstance(c.default, schema.Sequence): proc = self.process(c.default) if proc is not None: values.append((c, proc)) if not c.primary_key: - self.postfetch.add(c) + self.postfetch.append(c) elif self.isupdate: if isinstance(c.onupdate, schema.ColumnDefault): if 
isinstance(c.onupdate.arg, sql.ClauseElement): values.append((c, self.process(c.onupdate.arg.self_group()))) - self.postfetch.add(c) + self.postfetch.append(c) else: values.append((c, create_bind_param(c, None))) - self.prefetch.add(c) + self.prefetch.append(c) elif isinstance(c.onupdate, schema.PassiveDefault): - self.postfetch.add(c) + self.postfetch.append(c) return values def visit_delete(self, delete_stmt): diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index de5797059c..5aa985f472 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -16,7 +16,7 @@ def sort_tables(tables, reverse=False): vis = TVisitor() for table in tables: vis.traverse(table) - sequence = topological.QueueDependencySorter( tuples, tables).sort(ignore_self_cycles=True, create_tree=False) + sequence = topological.sort(tuples, tables) if reverse: return util.reversed(sequence) else: diff --git a/lib/sqlalchemy/topological.py b/lib/sqlalchemy/topological.py index 258c3bf741..209d3455c4 100644 --- a/lib/sqlalchemy/topological.py +++ b/lib/sqlalchemy/topological.py @@ -21,14 +21,47 @@ conditions. from sqlalchemy import util from sqlalchemy.exceptions import CircularDependencyError -class _Node(object): - """Represent each item in the sort. +__all__ = ['sort', 'sort_with_cycles', 'sort_as_tree'] - While the topological sort produces a straight ordered list of - items, ``_Node`` ultimately stores a tree-structure of those items - which are organized so that non-dependent nodes are siblings. +def sort(tuples, allitems): + """sort the given list of items by dependency. + + 'tuples' is a list of tuples representing a partial ordering. + """ + + return [n.item for n in _sort(tuples, allitems, allow_cycles=False, ignore_self_cycles=True)] + +def sort_with_cycles(tuples, allitems): + """sort the given list of items by dependency, cutting out cycles. + + returns results as an iterable of 2-tuples, containing the item, + and a list containing items involved in a cycle with this item, if any. + + 'tuples' is a list of tuples representing a partial ordering. + """ + + return [(n.item, [n.item for n in n.cycles or []]) for n in _sort(tuples, allitems, allow_cycles=True)] + +def sort_as_tree(tuples, allitems, with_cycles=False): + """sort the given list of items by dependency, and return results + as a hierarchical tree structure. + + returns results as an iterable of 3-tuples, containing the item, + and a list containing items involved in a cycle with this item, if any, + and a list of child tuples. + + if with_cycles is False, the returned structure is of the same form + but the second element of each tuple, i.e. the 'cycles', is an empty list. + + 'tuples' is a list of tuples representing a partial ordering. """ + return _organize_as_tree(_sort(tuples, allitems, allow_cycles=with_cycles)) + + +class _Node(object): + """Represent each item in the sort.""" + def __init__(self, item): self.item = item self.dependencies = util.Set() @@ -37,7 +70,7 @@ class _Node(object): def __str__(self): return self.safestr() - + def safestr(self, indent=0): return (' ' * indent * 2) + \ str(self.item) + \ @@ -130,168 +163,145 @@ class _EdgeCollection(object): def __repr__(self): return repr(list(self)) -class QueueDependencySorter(object): - """Topological sort adapted from wikipedia's article on the subject. - - It creates a straight-line list of elements, then a second pass - batches non-dependent elements as siblings in a tree structure. 
Future
-    versions of this algorithm may separate the "convert to a tree"
-    step.
+def _sort(tuples, allitems, allow_cycles=False, ignore_self_cycles=False):
+    nodes = {}
+    edges = _EdgeCollection()
+    for item in list(allitems) + [t[0] for t in tuples] + [t[1] for t in tuples]:
+        if item not in nodes:
+            node = _Node(item)
+            nodes[item] = node
+
+    for t in tuples:
+        if t[0] is t[1]:
+            if allow_cycles:
+                n = nodes[t[0]]
+                n.cycles = util.Set([n])
+            elif not ignore_self_cycles:
+                raise CircularDependencyError("Self-referential dependency detected " + repr(t))
+            continue
+        childnode = nodes[t[1]]
+        parentnode = nodes[t[0]]
+        edges.add((parentnode, childnode))
+
+    queue = []
+    for n in nodes.values():
+        if not edges.has_parents(n):
+            queue.append(n)
+
+    output = []
+    while nodes:
+        if not queue:
+            # edges remain but no edgeless nodes to remove; this indicates
+            # a cycle
+            if allow_cycles:
+                for cycle in _find_cycles(edges):
+                    lead = cycle[0][0]
+                    lead.cycles = util.Set()
+                    for edge in cycle:
+                        n = edges.remove(edge)
+                        lead.cycles.add(edge[0])
+                        lead.cycles.add(edge[1])
+                        if n is not None:
+                            queue.append(n)
+                    for n in lead.cycles:
+                        if n is not lead:
+                            n._cyclical = True
+                            for (n,k) in list(edges.edges_by_parent(n)):
+                                edges.add((lead, k))
+                                edges.remove((n,k))
+                continue
+            else:
+                # long cycles not allowed
+                raise CircularDependencyError("Circular dependency detected " + repr(edges) + repr(queue))
+        node = queue.pop()
+        if not hasattr(node, '_cyclical'):
+            output.append(node)
+        del nodes[node.item]
+        for childnode in edges.pop_node(node):
+            queue.append(childnode)
+    return output
+
+def _organize_as_tree(nodes):
+    """Given a list of nodes from a topological sort, organize the
+    nodes into a tree structure, with as many non-dependent nodes
+    set as siblings to each other as possible.
+
+    returns nodes as 3-tuples (item, cycles, children). 
""" - def __init__(self, tuples, allitems): - self.tuples = tuples - self.allitems = allitems - - def sort(self, allow_cycles=False, ignore_self_cycles=False, create_tree=True): - (tuples, allitems) = (self.tuples, self.allitems) - #print "\n---------------------------------\n" - #print repr([t for t in tuples]) - #print repr([a for a in allitems]) - #print "\n---------------------------------\n" - - nodes = {} - edges = _EdgeCollection() - for item in list(allitems) + [t[0] for t in tuples] + [t[1] for t in tuples]: - if id(item) not in nodes: - node = _Node(item) - nodes[id(item)] = node - - for t in tuples: - if t[0] is t[1]: - if allow_cycles: - n = nodes[id(t[0])] - n.cycles = util.Set([n]) - elif not ignore_self_cycles: - raise CircularDependencyError("Self-referential dependency detected " + repr(t)) - continue - childnode = nodes[id(t[1])] - parentnode = nodes[id(t[0])] - edges.add((parentnode, childnode)) - - queue = [] - for n in nodes.values(): - if not edges.has_parents(n): - queue.append(n) - - output = [] - while nodes: - if not queue: - # edges remain but no edgeless nodes to remove; this indicates - # a cycle - if allow_cycles: - for cycle in self._find_cycles(edges): - lead = cycle[0][0] - lead.cycles = util.Set() - for edge in cycle: - n = edges.remove(edge) - lead.cycles.add(edge[0]) - lead.cycles.add(edge[1]) - if n is not None: - queue.append(n) - for n in lead.cycles: - if n is not lead: - n._cyclical = True - for (n,k) in list(edges.edges_by_parent(n)): - edges.add((lead, k)) - edges.remove((n,k)) - continue - else: - # long cycles not allowed - raise CircularDependencyError("Circular dependency detected " + repr(edges) + repr(queue)) - node = queue.pop() - if not hasattr(node, '_cyclical'): - output.append(node) - del nodes[id(node.item)] - for childnode in edges.pop_node(node): - queue.append(childnode) - if create_tree: - return self._create_batched_tree(output) - elif allow_cycles: - return output + if not nodes: + return None + # a list of all currently independent subtrees as a tuple of + # (root_node, set_of_all_tree_nodes, set_of_all_cycle_nodes_in_tree) + # order of the list has no semantics for the algorithmic + independents = [] + # in reverse topological order + for node in util.reversed(nodes): + # nodes subtree and cycles contain the node itself + subtree = util.Set([node]) + if node.cycles is not None: + cycles = util.Set(node.cycles) else: - return [n.item for n in output] - - - def _create_batched_tree(self, nodes): - """Given a list of nodes from a topological sort, organize the - nodes into a tree structure, with as many non-dependent nodes - set as siblings to each other as possible. 
-        """
-
-        if not nodes:
-            return None
-        # a list of all currently independent subtrees as a tuple of
-        # (root_node, set_of_all_tree_nodes, set_of_all_cycle_nodes_in_tree)
-        # order of the list has no semantics for the algorithmic
-        independents = []
-        # in reverse topological order
-        for node in util.reversed(nodes):
-            # nodes subtree and cycles contain the node itself
-            subtree = util.Set([node])
-            if node.cycles is not None:
-                cycles = util.Set(node.cycles)
-            else:
-                cycles = util.Set()
-            # get a set of dependent nodes of node and its cycles
-            nodealldeps = node.all_deps()
-            if nodealldeps:
-                # iterate over independent node indexes in reverse order so we can efficiently remove them
-                for index in xrange(len(independents)-1,-1,-1):
-                    child, childsubtree, childcycles = independents[index]
-                    # if there is a dependency between this node and an independent node
-                    if (childsubtree.intersection(nodealldeps) or childcycles.intersection(node.dependencies)):
-                        # prepend child to nodes children
-                        # (append should be fine, but previous implemetation used prepend)
-                        node.children[0:0] = (child,)
-                        # merge childs subtree and cycles
-                        subtree.update(childsubtree)
-                        cycles.update(childcycles)
-                        # remove the child from list of independent subtrees
-                        independents[index:index+1] = []
-            # add node as a new independent subtree
-            independents.append((node,subtree,cycles))
-        # choose an arbitrary node from list of all independent subtrees
-        head = independents.pop()[0]
-        # add all other independent subtrees as a child of the chosen root
-        # used prepend [0:0] instead of extend to maintain exact behaviour of previous implementation
-        head.children[0:0] = [i[0] for i in independents]
-        return head
-
-    def _find_cycles(self, edges):
-        involved_in_cycles = util.Set()
-        cycles = {}
-        def traverse(node, goal=None, cycle=None):
-            if goal is None:
-                goal = node
-                cycle = []
-            elif node is goal:
-                return True
-
-            for (n, key) in edges.edges_by_parent(node):
-                if key in cycle:
-                    continue
-                cycle.append(key)
-                if traverse(key, goal, cycle):
-                    cycset = util.Set(cycle)
-                    for x in cycle:
-                        involved_in_cycles.add(x)
-                        if x in cycles:
-                            existing_set = cycles[x]
-                            [existing_set.add(y) for y in cycset]
-                            for y in existing_set:
-                                cycles[y] = existing_set
-                            cycset = existing_set
-                        else:
-                            cycles[x] = cycset
-                cycle.pop()
-
-        for parent in edges.get_parents():
-            traverse(parent)
-
-        # sets are not hashable, so uniquify with id
-        unique_cycles = dict([(id(s), s) for s in cycles.values()]).values()
-        for cycle in unique_cycles:
-            edgecollection = [edge for edge in edges
-                              if edge[0] in cycle and edge[1] in cycle]
-            yield edgecollection
+        cycles = util.Set()
+        # get a set of dependent nodes of node and its cycles
+        nodealldeps = node.all_deps()
+        if nodealldeps:
+            # iterate over independent node indexes in reverse order so we can efficiently remove them
+            for index in xrange(len(independents)-1,-1,-1):
+                child, childsubtree, childcycles = independents[index]
+                # if there is a dependency between this node and an independent node
+                if (childsubtree.intersection(nodealldeps) or childcycles.intersection(node.dependencies)):
+                    # prepend child to the node's children
+                    # (append should be fine, but the previous implementation used prepend)
+                    node.children[0:0] = [(child.item, [n.item for n in child.cycles or []], child.children)]
+                    # merge the child's subtree and cycles
+                    subtree.update(childsubtree)
+                    cycles.update(childcycles)
+                    # remove the child from the list of independent subtrees
+                    independents[index:index+1] = []
+        # add node as a new independent subtree
+        independents.append((node,subtree,cycles))
+    # choose an arbitrary node from the list of all independent subtrees
+    head = independents.pop()[0]
+    # add all other independent subtrees as a child of the chosen root
+    # used prepend [0:0] instead of extend to maintain exact behaviour of previous implementation
+    head.children[0:0] = [(i[0].item, [n.item for n in i[0].cycles or []], i[0].children) for i in independents]
+    return (head.item, [n.item for n in head.cycles or []], head.children)
+
+def _find_cycles(edges):
+    involved_in_cycles = util.Set()
+    cycles = {}
+    def traverse(node, goal=None, cycle=None):
+        if goal is None:
+            goal = node
+            cycle = []
+        elif node is goal:
+            return True
+
+        for (n, key) in edges.edges_by_parent(node):
+            if key in cycle:
+                continue
+            cycle.append(key)
+            if traverse(key, goal, cycle):
+                cycset = util.Set(cycle)
+                for x in cycle:
+                    involved_in_cycles.add(x)
+                    if x in cycles:
+                        existing_set = cycles[x]
+                        [existing_set.add(y) for y in cycset]
+                        for y in existing_set:
+                            cycles[y] = existing_set
+                        cycset = existing_set
+                    else:
+                        cycles[x] = cycset
+            cycle.pop()
+
+    for parent in edges.get_parents():
+        traverse(parent)
+
+    # sets are not hashable, so uniquify with id
+    unique_cycles = dict([(id(s), s) for s in cycles.values()]).values()
+    for cycle in unique_cycles:
+        edgecollection = [edge for edge in edges
+                          if edge[0] in cycle and edge[1] in cycle]
+        yield edgecollection
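(To summarize the module's three public faces, a hedged sketch of the structures
they return -- which member of a mutually-dependent group "leads" its cycle is an
implementation detail, so the exact output may vary:)

    from sqlalchemy import topological

    # 'a' and 'b' depend on each other; both must precede 'c'
    tuples = [('a', 'b'), ('b', 'a'), ('a', 'c')]
    items = ['a', 'b', 'c']

    # sort() disallows cycles and would raise CircularDependencyError here

    # sort_with_cycles() returns (item, cycle_members) 2-tuples; one item
    # stands in for its whole cycle, with the cycle members listed beside it
    for item, cycles in topological.sort_with_cycles(tuples, items):
        print item, cycles
    # e.g.:  a ['a', 'b']
    #        c []

    # sort_as_tree() returns a single (item, cycles, children) 3-tuple
    # whose children nest recursively in the same form
    head = topological.sort_as_tree(tuples, items, with_cycles=True)
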
diff --git a/test/base/dependency.py b/test/base/dependency.py
index a3d03e2fc5..af5c842880 100644
--- a/test/base/dependency.py
+++ b/test/base/dependency.py
@@ -4,27 +4,21 @@
 from sqlalchemy import util
 from testlib import *
 
-# TODO: need assertion conditions in this suite
-
-
-class DependencySorter(topological.QueueDependencySorter):pass
-
-
 class DependencySortTest(PersistTest):
     def assert_sort(self, tuples, node, collection=None):
         print str(node)
         def assert_tuple(tuple, node):
-            if node.cycles:
-                cycles = [i.item for i in node.cycles]
+            if node[1]:
+                cycles = node[1]
             else:
                 cycles = []
-            if tuple[0] is node.item or tuple[0] in cycles:
+            if tuple[0] is node[0] or tuple[0] in cycles:
                 tuple.pop()
-            if tuple[0] is node.item or tuple[0] in cycles:
+            if tuple[0] is node[0] or tuple[0] in cycles:
                 return
-            elif len(tuple) > 1 and tuple[1] is node.item:
+            elif len(tuple) > 1 and tuple[1] is node[0]:
                 assert False, "Tuple not in dependency tree: " + str(tuple)
-            for c in node.children:
+            for c in node[2]:
                 assert_tuple(tuple, c)
 
         for tuple in tuples:
@@ -34,12 +28,12 @@ class DependencySortTest(PersistTest):
             collection = []
         items = util.Set()
         def assert_unique(node):
-            for item in [n.item for n in node.cycles or [node,]]:
+            for item in [i for i in node[1] or [node[0]]]:
                 assert item not in items
                 items.add(item)
                 if item in collection:
                     collection.remove(item)
-            for c in node.children:
+            for c in node[2]:
                 assert_unique(c)
         assert_unique(node)
         assert len(collection) == 0
@@ -64,7 +58,7 @@ class DependencySortTest(PersistTest):
             (node4, subnode3),
             (node4, subnode4)
         ]
-        head = DependencySorter(tuples, []).sort()
+        head = topological.sort_as_tree(tuples, [])
         self.assert_sort(tuples, head)
 
     def testsort2(self):
@@ -82,7 +76,7 @@ class DependencySortTest(PersistTest):
             (node5, node6),
             (node6, node2)
         ]
-        head = DependencySorter(tuples, [node7]).sort()
+        head = topological.sort_as_tree(tuples, [node7])
         self.assert_sort(tuples, head, [node7])
 
     def testsort3(self):
@@ -95,9 +89,9 @@ class DependencySortTest(PersistTest):
             (node3, node2),
             (node1,node3)
         ]
-        head1 = DependencySorter(tuples, [node1, node2, node3]).sort()
-        head2 = DependencySorter(tuples, [node3, node1, node2]).sort()
-        head3 = DependencySorter(tuples, [node3, node2, node1]).sort()
+        head1 = topological.sort_as_tree(tuples, [node1, node2, node3])
+        head2 = topological.sort_as_tree(tuples, [node3, node1, node2])
+        head3 = topological.sort_as_tree(tuples, [node3, node2, node1])
 
         # TODO: figure out a "node == node2" function
         #self.assert_(str(head1) == str(head2) == str(head3))
@@ -116,7 +110,7 @@ class DependencySortTest(PersistTest):
             (node1, node3),
             (node3, node2)
         ]
-        head = DependencySorter(tuples, []).sort()
+        head = topological.sort_as_tree(tuples, [])
         self.assert_sort(tuples, head)
 
     def testsort5(self):
@@ -139,7 +133,7 @@ class DependencySortTest(PersistTest):
             node3,
             node4
         ]
-        head = DependencySorter(tuples, allitems).sort(ignore_self_cycles=True)
+        head = topological.sort_as_tree(tuples, allitems, with_cycles=True)
         self.assert_sort(tuples, head)
 
     def testcircular(self):
@@ -156,7 +150,7 @@ class DependencySortTest(PersistTest):
             (node3, node1),
             (node4, node1)
         ]
-        head = DependencySorter(tuples, []).sort(allow_cycles=True)
+        head = topological.sort_as_tree(tuples, [], with_cycles=True)
         self.assert_sort(tuples, head)
 
     def testcircular2(self):
@@ -173,20 +167,20 @@ class DependencySortTest(PersistTest):
             (node3, node2),
             (node2, node3)
         ]
-        head = DependencySorter(tuples, []).sort(allow_cycles=True)
+        head = topological.sort_as_tree(tuples, [], with_cycles=True)
         self.assert_sort(tuples, head)
 
     def testcircular3(self):
         nodes = {}
         tuples = [('Question', 'Issue'), ('ProviderService', 'Issue'), ('Provider', 'Question'), ('Question', 'Provider'), ('ProviderService', 'Question'), ('Provider', 'ProviderService'), ('Question', 'Answer'), ('Issue', 'Question')]
-        head = DependencySorter(tuples, []).sort(allow_cycles=True)
+        head = topological.sort_as_tree(tuples, [], with_cycles=True)
         self.assert_sort(tuples, head)
 
     def testbigsort(self):
         tuples = []
         for i in range(0,1500, 2):
             tuples.append((i, i+1))
-        head = DependencySorter(tuples, []).sort()
+        head = topological.sort_as_tree(tuples, [])
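(Since tree nodes are now plain 3-tuples rather than _Node instances, helpers like
assert_sort() above can unpack them directly; a similar minimal walk -- the walk()
helper is hypothetical, not part of the library:)

    from sqlalchemy import topological

    def walk(node, depth=0):
        item, cycles, children = node
        print '    ' * depth + repr(item) + (cycles and ' cycles=%r' % cycles or '')
        for child in children:
            walk(child, depth + 1)

    walk(topological.sort_as_tree([(1, 2), (1, 3)], []))
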
diff --git a/test/orm/attributes.py b/test/orm/attributes.py
index b321dc50a1..4d56a01ec1 100644
--- a/test/orm/attributes.py
+++ b/test/orm/attributes.py
@@ -225,15 +225,16 @@ class AttributesTest(PersistTest):
         attributes.register_attribute(Foo, 'element', uselist=False, useobject=True)
         x = Bar()
         x.element = 'this is the element'
-        hist = attributes.get_history(x, 'element')
-        assert hist.added_items() == ['this is the element']
+        (added, unchanged, deleted) = attributes.get_history(x._state, 'element')
+        assert added == ['this is the element']
         x._state.commit_all()
-        hist = attributes.get_history(x, 'element')
-        assert hist.added_items() == []
-        assert hist.unchanged_items() == ['this is the element']
+        (added, unchanged, deleted) = attributes.get_history(x._state, 'element')
+        assert added == []
+        assert unchanged == ['this is the element']
 
     def test_lazyhistory(self):
         """tests that history functions work with lazy-loading attributes"""
+
         class Foo(object):pass
         class Bar(object):
             def __init__(self, id):
@@ -257,16 +258,13 @@ class AttributesTest(PersistTest):
         x = Foo()
         x._state.commit_all()
         x.col2.append(Bar(4))
-        h = attributes.get_history(x, 'col2')
-        print h.added_items()
-        print h.unchanged_items()
+        (added, unchanged, deleted) = attributes.get_history(x._state, 'col2')
 
     def test_parenttrack(self):
         class Foo(object):pass
         class Bar(object):pass
-
         attributes.register_class(Foo)
         attributes.register_class(Bar)
@@ -299,7 +297,7 @@ class AttributesTest(PersistTest):
         x.element = ['one', 'two', 'three']
         x._state.commit_all()
         x.element[1] = 'five'
-        assert attributes.is_modified(x)
+        assert x._state.is_modified()
 
         attributes.unregister_class(Foo)
@@ -309,7 +307,7 @@ class AttributesTest(PersistTest):
         x.element = ['one', 'two', 'three']
         x._state.commit_all()
         x.element[1] = 'five'
-        assert not attributes.is_modified(x)
+        assert not x._state.is_modified()
 
     def test_descriptorattributes(self):
         """changeset: 1633 broke ability to use ORM to map classes with unusual
diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py
index 5affaa238f..05603ac864 100644
--- a/test/orm/inheritance/basic.py
+++ b/test/orm/inheritance/basic.py
@@ -56,10 +56,11 @@ class O2MTest(ORMTest):
         b1.parent_foo = f
         b2.parent_foo = f
         sess.flush()
-        compare = repr(b1) + repr(b2) + repr(b1.parent_foo) + repr(b2.parent_foo)
+        compare = ','.join([repr(b1), repr(b2), repr(b1.parent_foo), repr(b2.parent_foo)])
         sess.clear()
         l = sess.query(Blub).select()
-        result = repr(l[0]) + repr(l[1]) + repr(l[0].parent_foo) + repr(l[1].parent_foo)
+        result = ','.join([repr(l[0]), repr(l[1]), repr(l[0].parent_foo), repr(l[1].parent_foo)])
+        print compare
         print result
         self.assert_(compare == result)
         self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1')
diff --git a/test/orm/session.py b/test/orm/session.py
index a662371480..7bd8b666d7 100644
--- a/test/orm/session.py
+++ b/test/orm/session.py
@@ -634,6 +634,7 @@ class SessionTest(AssertMixin):
         assert u in s
         assert a not in s
         s.flush()
+        print "\n".join([repr(x.__dict__) for x in s])
         s.clear()
         assert s.query(User).one().user_id == u.user_id
         assert s.query(Address).first() is None
diff --git a/test/sql/defaults.py b/test/sql/defaults.py
index c385b0ac6b..a50250e9b4 100644
--- a/test/sql/defaults.py
+++ b/test/sql/defaults.py
@@ -136,11 +136,11 @@ class DefaultTest(PersistTest):
     def testinsert(self):
         r = t.insert().execute()
         assert r.lastrow_has_defaults()
-        assert util.Set(r.context.postfetch_cols()) == util.Set([t.c.col3, t.c.col5, t.c.col4, t.c.col6])
+        assert util.Set(r.context.postfetch_cols) == util.Set([t.c.col3, t.c.col5, t.c.col4, t.c.col6])
 
         r = t.insert(inline=True).execute()
         assert r.lastrow_has_defaults()
-        assert util.Set(r.context.postfetch_cols()) == util.Set([t.c.col3, t.c.col5, t.c.col4, t.c.col6])
+        assert util.Set(r.context.postfetch_cols) == util.Set([t.c.col3, t.c.col5, t.c.col4, t.c.col6])
 
         t.insert().execute()
         t.insert().execute()
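(Finally, a condensed sketch of the new attribute-history calling convention,
adapted from the attributes.py test above: get_history() now takes the
InstanceState and hands back a plain (added, unchanged, deleted) 3-tuple:)

    from sqlalchemy.orm import attributes

    class Foo(object): pass
    attributes.register_class(Foo)
    attributes.register_attribute(Foo, 'element', uselist=False, useobject=True)

    x = Foo()
    x.element = 'this is the element'
    (added, unchanged, deleted) = attributes.get_history(x._state, 'element')
    assert added == ['this is the element']

    x._state.commit_all()
    (added, unchanged, deleted) = attributes.get_history(x._state, 'element')
    assert added == [] and unchanged == ['this is the element']

and on the engine side, as the defaults.py test exercises, the context's
postfetch_cols is now read as an attribute rather than called:

    r = t.insert().execute()
    assert r.lastrow_has_defaults()
    print r.context.postfetch_cols    # an attribute now, not a method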