From: Mike Bayer Date: Fri, 14 Dec 2007 05:53:18 +0000 (+0000) Subject: - merged instances_yields branch r3908:3934, minus the "yield" part which remains... X-Git-Tag: rel_0_4_2~53 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0df750223a5f6ee4cfa987a4abd5ab4691007350;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - merged instances_yields branch r3908:3934, minus the "yield" part which remains slightly problematic - cleanup of mapper._instance, query.instances(). mapper identifies objects which are part of the current load using a app-unique id on the query context. - attributes refactor; attributes now mostly use copy-on-modify instead of copy-on-load behavior, simplified get_history(), added a new set of tests - fixes to OrderedSet such that difference(), intersection() and others can accept an iterator - OrderedIdentitySet passes in OrderedSet to the IdentitySet superclass for usage in difference/intersection/etc. operations so that these methods actually work with ordering behavior. - query.order_by() takes into account aliased joins, i.e. query.join('orders', aliased=True).order_by(Order.id) - cleanup etc. --- diff --git a/CHANGES b/CHANGES index e970b0eb05..a19d5302c9 100644 --- a/CHANGES +++ b/CHANGES @@ -78,10 +78,23 @@ CHANGES new behavior allows not just joins from the main table, but select statements as well. Filter criterion, order bys, eager load clauses will be "aliased" against the given statement. - + + - this month's refactoring of attribute instrumentation changes + the "copy-on-load" behavior we've had since midway through 0.3 + with "copy-on-modify" in most cases. This takes a sizable chunk + of latency out of load operations and overall does less work + as only attributes which are actually modified get their + "committed state" copied. Only "mutable scalar" attributes + (i.e. a pickled object or other mutable item), the reason for + the copy-on-load change in the first place, retain the old + behavior. + - query.filter(SomeClass.somechild == None), when comparing a many-to-one property to None, properly generates "id IS NULL" including that the NULL is on the right side. + + - query.order_by() takes into account aliased joins, i.e. + query.join('orders', aliased=True).order_by(Order.id) - eagerload(), lazyload(), eagerload_all() take an optional second class-or-mapper argument, which will select the mapper diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index a26bc2b58d..af589f3403 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -15,6 +15,7 @@ from sqlalchemy import exceptions PASSIVE_NORESULT = object() ATTR_WAS_SET = object() NO_VALUE = object() +NEVER_SET = object() class InstrumentedAttribute(interfaces.PropComparator): """public-facing instrumented attribute, placed in the @@ -73,29 +74,29 @@ class ProxiedAttribute(InstrumentedAttribute): class ProxyImpl(object): def __init__(self, key): self.key = key - - def commit_to_state(self, state, value=NO_VALUE): - pass - + def __init__(self, key, user_prop, comparator=None): self.user_prop = user_prop self.comparator = comparator self.key = key self.impl = ProxiedAttribute.ProxyImpl(key) + def __get__(self, instance, owner): if instance is None: self.user_prop.__get__(instance, owner) return self return self.user_prop.__get__(instance, owner) + def __set__(self, instance, value): return self.user_prop.__set__(instance, value) + def __delete__(self, instance): return self.user_prop.__delete__(instance) class AttributeImpl(object): """internal implementation for instrumented attributes.""" - def __init__(self, class_, key, callable_, trackparent=False, extension=None, compare_function=None, mutable_scalars=False, **kwargs): + def __init__(self, class_, key, callable_, trackparent=False, extension=None, compare_function=None, **kwargs): """Construct an AttributeImpl. class_ @@ -122,37 +123,18 @@ class AttributeImpl(object): a function that compares two values which are normally assignable to this attribute. - mutable_scalars - if True, the values which are normally assignable to this - attribute can mutate, and need to be compared against a copy of - their original contents in order to detect changes on the parent - instance. """ self.class_ = class_ self.key = key self.callable_ = callable_ self.trackparent = trackparent - self.mutable_scalars = mutable_scalars - if mutable_scalars: - class_._class_state.has_mutable_scalars = True - self.copy = None if compare_function is None: self.is_equal = operator.eq else: self.is_equal = compare_function self.extensions = util.to_list(extension or []) - def commit_to_state(self, state, value=NO_VALUE): - """Commits the object's current state to its 'committed' state.""" - - if value is NO_VALUE: - if self.key in state.dict: - value = state.dict[self.key] - if value is not NO_VALUE: - state.committed_state[self.key] = self.copy(value) - state.pending.pop(self.key, None) - def hasparent(self, state, optimistic=False): """Return the boolean value of a `hasparent` flag attached to the given item. @@ -178,14 +160,7 @@ class AttributeImpl(object): state.parents[id(self)] = value - def get_history(self, state, passive=False): - current = self.get(state, passive=passive) - if current is PASSIVE_NORESULT: - return (None, None, None) - else: - return _create_history(self, state, current) - - def set_callable(self, state, callable_, clear=False): + def set_callable(self, state, callable_): """Set a callable function for this attribute on the given object. This callable will be executed when the attribute is next @@ -200,14 +175,14 @@ class AttributeImpl(object): ``InstrumentedAttribute` constructor. """ - if clear: - self.clear(state) - if callable_ is None: self.initialize(state) else: state.callables[self.key] = callable_ + def get_history(self, state, passive=False): + raise NotImplementedError() + def _get_callable(self, state): if self.key in state.callables: return state.callables[self.key] @@ -216,9 +191,6 @@ class AttributeImpl(object): else: return None - def check_mutable_modified(self, state): - return False - def initialize(self, state): """Initialize this attribute on the given object instance with an empty value.""" @@ -261,125 +233,117 @@ class AttributeImpl(object): raise NotImplementedError() def get_committed_value(self, state): - if state.committed_state is not None: - if self.key not in state.committed_state: - self.get(state) + """return the unchanged value of this attribute""" + + if self.key in state.committed_state: return state.committed_state.get(self.key) else: - return None + return self.get(state) def set_committed_value(self, state, value): - """set an attribute value on the given instance and 'commit' it. - - this indicates that the given value is the "persisted" value, - and history will be logged only if a newly set value is not - equal to this value. - - this is typically used by deferred/lazy attribute loaders - to set object attributes after the initial load. - """ + """set an attribute value on the given instance and 'commit' it.""" if state.committed_state is not None: - self.commit_to_state(state, value) + state.commit_attr(self, value) # remove per-instance callable, if any - state.callables.pop(self, None) + state.callables.pop(self.key, None) state.dict[self.key] = value return value - def fire_append_event(self, state, value, initiator): - state.modified = True - if self.trackparent and value is not None: - self.sethasparent(value._state, True) - instance = state.obj() - for ext in self.extensions: - ext.append(instance, value, initiator or self) +class ScalarAttributeImpl(AttributeImpl): + """represents a scalar value-holding InstrumentedAttribute.""" - def fire_remove_event(self, state, value, initiator): - state.modified = True - if self.trackparent and value is not None: - self.sethasparent(value._state, False) - instance = state.obj() - for ext in self.extensions: - ext.remove(instance, value, initiator or self) + accepts_global_callable = True + + def delete(self, state): + del state.dict[self.key] + state.modified=True - def fire_replace_event(self, state, value, previous, initiator): - state.modified = True - if self.trackparent: - if value is not None: - self.sethasparent(value._state, True) - if previous is not None: - self.sethasparent(previous._state, False) - instance = state.obj() - for ext in self.extensions: - ext.set(instance, value, previous, initiator or self) + def get_history(self, state, passive=False): + return _create_history(self, state, state.dict.get(self.key, NO_VALUE)) -class ScalarAttributeImpl(AttributeImpl): - """represents a scalar value-holding InstrumentedAttribute.""" - def __init__(self, class_, key, callable_, copy_function=None, compare_function=None, mutable_scalars=False, **kwargs): - super(ScalarAttributeImpl, self).__init__(class_, key, - callable_, compare_function=compare_function, mutable_scalars=mutable_scalars, **kwargs) + def set(self, state, value, initiator): + if initiator is self: + return + + if self.key not in state.committed_state: + state.committed_state[self.key] = state.dict.get(self.key, NO_VALUE) + state.dict[self.key] = value + state.modified=True + + type = property(lambda self: self.property.columns[0].type) + +class MutableScalarAttributeImpl(ScalarAttributeImpl): + """represents a scalar value-holding InstrumentedAttribute, which can detect + changes within the value itself. + """ + + def __init__(self, class_, key, callable_, copy_function=None, compare_function=None, **kwargs): + super(ScalarAttributeImpl, self).__init__(class_, key, callable_, compare_function=compare_function, **kwargs) + class_._class_state.has_mutable_scalars = True if copy_function is None: - copy_function = self.__copy + raise exceptions.ArgumentError("MutableScalarAttributeImpl requires a copy function") self.copy = copy_function - self.accepts_global_callable = True - - def __copy(self, item): - # scalar values are assumed to be immutable unless a copy function - # is passed - return item - def delete(self, state): - del state.dict[self.key] - state.modified=True + def get_history(self, state, passive=False): + return _create_history(self, state, state.dict.get(self.key, NO_VALUE)) + + def commit_to_state(self, state, value): + state.committed_state[self.key] = self.copy(value) def check_mutable_modified(self, state): - if self.mutable_scalars: - (added, unchanged, deleted) = self.get_history(state, passive=True) - if added or deleted: - state.modified = True - return True - else: - return False + (added, unchanged, deleted) = self.get_history(state, passive=True) + if added or deleted: + state.modified = True + return True else: return False def set(self, state, value, initiator): - """Set a value on the given InstanceState. - - `initiator` is the ``InstrumentedAttribute`` that initiated the - ``set()` operation and is used to control the depth of a circular - setter operation. - """ - if initiator is self: return + if self.key not in state.committed_state: + if self.key in state.dict: + state.committed_state[self.key] = self.copy(state.dict[self.key]) + else: + state.committed_state[self.key] = NO_VALUE + state.dict[self.key] = value state.modified=True - type = property(lambda self: self.property.columns[0].type) - class ScalarObjectAttributeImpl(ScalarAttributeImpl): - """represents a scalar class-instance holding InstrumentedAttribute. + """represents a scalar-holding InstrumentedAttribute, where the target object is also instrumented. Adds events to delete/set operations. """ - - def __init__(self, class_, key, callable_, trackparent=False, extension=None, copy_function=None, compare_function=None, mutable_scalars=False, **kwargs): + + accepts_global_callable = False + + def __init__(self, class_, key, callable_, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs): super(ScalarObjectAttributeImpl, self).__init__(class_, key, callable_, trackparent=trackparent, extension=extension, - compare_function=compare_function, mutable_scalars=mutable_scalars, **kwargs) + compare_function=compare_function, **kwargs) if compare_function is None: self.is_equal = identity_equal - self.accepts_global_callable = False - + def delete(self, state): old = self.get(state) del state.dict[self.key] self.fire_remove_event(state, old, self) + def get_history(self, state, passive=False): + if self.key in state.dict: + return _create_history(self, state, state.dict[self.key]) + else: + current = self.get(state, passive=passive) + if current is PASSIVE_NORESULT: + return (None, None, None) + else: + return _create_history(self, state, current) + def set(self, state, value, initiator): """Set a value on the given InstanceState. @@ -391,19 +355,49 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): if initiator is self: return + # TODO: add options to allow the get() to be passive old = self.get(state) state.dict[self.key] = value self.fire_replace_event(state, value, old, initiator) - + def fire_remove_event(self, state, value, initiator): + if self.key not in state.committed_state: + state.committed_state[self.key] = value + state.modified = True + + if self.trackparent and value is not None: + self.sethasparent(value._state, False) + + instance = state.obj() + for ext in self.extensions: + ext.remove(instance, value, initiator or self) + + def fire_replace_event(self, state, value, previous, initiator): + if self.key not in state.committed_state: + state.committed_state[self.key] = previous + state.modified = True + + if self.trackparent: + if value is not None: + self.sethasparent(value._state, True) + if previous is not None: + self.sethasparent(previous._state, False) + + instance = state.obj() + for ext in self.extensions: + ext.set(instance, value, previous, initiator or self) + class CollectionAttributeImpl(AttributeImpl): """A collection-holding attribute that instruments changes in membership. + Only handles collections of instrumented objects. + InstrumentedCollectionAttribute holds an arbitrary, user-specified container object (defaulting to a list) and brokers access to the CollectionAdapter, a "view" onto that object that presents consistent bag semantics to the orm layer independent of the user data implementation. """ + accepts_global_callable = False def __init__(self, class_, key, callable_, typecallable=None, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs): super(CollectionAttributeImpl, self).__init__(class_, @@ -414,8 +408,6 @@ class CollectionAttributeImpl(AttributeImpl): copy_function = self.__copy self.copy = copy_function - self.accepts_global_callable = False - if typecallable is None: typecallable = list self.collection_factory = \ @@ -426,6 +418,59 @@ class CollectionAttributeImpl(AttributeImpl): def __copy(self, item): return [y for y in list(collections.collection_adapter(item))] + def get_history(self, state, passive=False): + if self.key in state.dict: + return _create_history(self, state, state.dict[self.key]) + else: + current = self.get(state, passive=passive) + if current is PASSIVE_NORESULT: + return (None, None, None) + else: + return _create_history(self, state, current) + + def fire_append_event(self, state, value, initiator): + if self.key not in state.committed_state: + if self.key in state.dict: + state.committed_state[self.key] = self.copy(state.dict[self.key]) + else: + state.committed_state[self.key] = NO_VALUE + state.modified = True + + if self.trackparent and value is not None: + self.sethasparent(value._state, True) + instance = state.obj() + for ext in self.extensions: + ext.append(instance, value, initiator or self) + + def fire_pre_remove_event(self, state, initiator): + if self.key not in state.committed_state: + if self.key in state.dict: + state.committed_state[self.key] = self.copy(state.dict[self.key]) + else: + state.committed_state[self.key] = NO_VALUE + + def fire_remove_event(self, state, value, initiator): + if self.key not in state.committed_state: + if self.key in state.dict: + state.committed_state[self.key] = self.copy(state.dict[self.key]) + else: + state.committed_state[self.key] = NO_VALUE + state.modified = True + + if self.trackparent and value is not None: + self.sethasparent(value._state, False) + + instance = state.obj() + for ext in self.extensions: + ext.remove(instance, value, initiator or self) + + def get_history(self, state, passive=False): + current = self.get(state, passive=passive) + if current is PASSIVE_NORESULT: + return (None, None, None) + else: + return _create_history(self, state, current) + def delete(self, state): if self.key not in state.dict: return @@ -447,9 +492,15 @@ class CollectionAttributeImpl(AttributeImpl): if initiator is self: return + if self.key not in state.committed_state: + if self.key in state.dict: + state.committed_state[self.key] = self.copy(state.dict[self.key]) + else: + state.committed_state[self.key] = NO_VALUE + collection = self.get_collection(state, passive=passive) if collection is PASSIVE_NORESULT: - state.get_pending(self).append(value) + state.get_pending(self.key).append(value) self.fire_append_event(state, value, initiator) else: collection.append_with_event(value, initiator) @@ -458,9 +509,15 @@ class CollectionAttributeImpl(AttributeImpl): if initiator is self: return + if self.key not in state.committed_state: + if self.key in state.dict: + state.committed_state[self.key] = self.copy(state.dict[self.key]) + else: + state.committed_state[self.key] = NO_VALUE + collection = self.get_collection(state, passive=passive) if collection is PASSIVE_NORESULT: - state.get_pending(self).remove(value) + state.get_pending(self.key).remove(value) self.fire_remove_event(state, value, initiator) else: collection.remove_with_event(value, initiator) @@ -476,6 +533,12 @@ class CollectionAttributeImpl(AttributeImpl): if initiator is self: return + if self.key not in state.committed_state: + if self.key in state.dict: + state.committed_state[self.key] = self.copy(state.dict[self.key]) + else: + state.committed_state[self.key] = NO_VALUE + # we need a CollectionAdapter to adapt the incoming value to an # assignable iterable. pulling a new collection first so that # an adaptation exception does not trigger a lazy load of the @@ -518,9 +581,9 @@ class CollectionAttributeImpl(AttributeImpl): value = user_data if state.committed_state is not None: - self.commit_to_state(state, value) + state.commit_attr(self, value) # remove per-instance callable, if any - state.callables.pop(self, None) + state.callables.pop(self.key, None) state.dict[self.key] = value return value @@ -618,12 +681,14 @@ class InstanceState(object): self.obj = weakref.ref(obj, self.__cleanup) self.dict = obj.__dict__ self.committed_state = {} - self.modified = self.strong = False + self.modified = False self.trigger = None self.callables = {} self.parents = {} self.pending = {} + self.appenders = {} self.instance_dict = None + self.runid = None def __cleanup(self, ref): # tiptoe around Python GC unpredictableness @@ -662,17 +727,17 @@ class InstanceState(object): finally: instance_dict._mutex.release() - def get_pending(self, attributeimpl): - if attributeimpl.key not in self.pending: - self.pending[attributeimpl.key] = PendingCollection() - return self.pending[attributeimpl.key] + def get_pending(self, key): + if key not in self.pending: + self.pending[key] = PendingCollection() + return self.pending[key] def is_modified(self): if self.modified: return True elif self.class_._class_state.has_mutable_scalars: for attr in _managed_attributes(self.class_): - if getattr(attr.impl, 'mutable_scalars', False) and attr.impl.check_mutable_modified(self): + if hasattr(attr.impl, 'check_mutable_modified') and attr.impl.check_mutable_modified(self): return True else: return False @@ -680,7 +745,7 @@ class InstanceState(object): return False def __resurrect(self, instance_dict): - if self.strong or self.is_modified(): + if self.is_modified(): # store strong ref'ed version of the object; will revert # to weakref when changes are persisted obj = new_instance(self.class_, state=self) @@ -745,6 +810,14 @@ class InstanceState(object): self.dict.pop(key, None) self.callables.pop(key, None) + def commit_attr(self, attr, value): + if hasattr(attr, 'commit_to_state'): + attr.commit_to_state(self, value) + else: + self.committed_state.pop(attr.key, None) + self.pending.pop(attr.key, None) + self.appenders.pop(attr.key, None) + def commit(self, keys): """commit all attributes named in the given list of key names. @@ -752,8 +825,20 @@ class InstanceState(object): which were refreshed from the database. """ - for key in keys: - getattr(self.class_, key).impl.commit_to_state(self) + if self.class_._class_state.has_mutable_scalars: + for key in keys: + attr = getattr(self.class_, key).impl + if hasattr(attr, 'commit_to_state') and attr.key in self.dict: + attr.commit_to_state(self, self.dict[attr.key]) + else: + self.committed_state.pop(attr.key, None) + self.pending.pop(key, None) + self.appenders.pop(key, None) + else: + for key in keys: + self.committed_state.pop(key, None) + self.pending.pop(key, None) + self.appenders.pop(key, None) def commit_all(self): """commit all attributes unconditionally. @@ -764,8 +849,14 @@ class InstanceState(object): self.committed_state = {} self.modified = False - for attr in _managed_attributes(self.class_): - attr.impl.commit_to_state(self) + self.pending = {} + self.appenders = {} + + if self.class_._class_state.has_mutable_scalars: + for attr in _managed_attributes(self.class_): + if hasattr(attr.impl, 'commit_to_state') and attr.impl.key in self.dict: + attr.impl.commit_to_state(self, self.dict[attr.impl.key]) + # remove strong ref self._strong_obj = None @@ -894,43 +985,30 @@ class StrongInstanceDict(dict): return [o._state for o in self.values()] def _create_history(attr, state, current): - if state.committed_state: - original = state.committed_state.get(attr.key, NO_VALUE) - else: - original = NO_VALUE + original = state.committed_state.get(attr.key, NEVER_SET) if hasattr(attr, 'get_collection'): if original is NO_VALUE: - s = util.IdentitySet([]) + return (list(current), [], []) + elif original is NEVER_SET: + return ([], list(current), []) else: - s = util.IdentitySet(original) - - _added_items = [] - _unchanged_items = [] - _deleted_items = [] - if current: - collection = attr.get_collection(state, current) - for a in collection: - if a in s: - _unchanged_items.append(a) - else: - _added_items.append(a) - _deleted_items = list(s.difference(_unchanged_items)) - - return (_added_items, _unchanged_items, _deleted_items) + collection = util.OrderedIdentitySet(attr.get_collection(state, current)) + s = util.OrderedIdentitySet(original) + return (list(collection.difference(s)), list(collection.intersection(s)), list(s.difference(collection))) else: - if attr.is_equal(current, original) is True: - _unchanged_items = [current] - _added_items = [] - _deleted_items = [] + if current is NO_VALUE: + return ([], [], []) + elif original is NO_VALUE: + return ([current], [], []) + elif original is NEVER_SET or attr.is_equal(current, original) is True: # dont let ClauseElement expressions here trip things up + return ([], [current], []) else: - _added_items = [current] - if original is not NO_VALUE and original is not None: - _deleted_items = [original] + if original is not None: + deleted = [original] else: - _deleted_items = [] - _unchanged_items = [] - return (_added_items, _unchanged_items, _deleted_items) + deleted = [] + return ([current], [], deleted) class PendingCollection(object): """stores items appended and removed from a collection that has not been loaded yet. @@ -960,7 +1038,6 @@ def _managed_attributes(class_): def get_history(state, key, **kwargs): return getattr(state.class_, key).impl.get_history(state, **kwargs) -get_state_history = get_history def get_as_list(state, key, passive=False): """return an InstanceState attribute as a list, @@ -985,18 +1062,18 @@ def get_as_list(state, key, passive=False): def has_parent(class_, instance, key, optimistic=False): return getattr(class_, key).impl.hasparent(instance._state, optimistic=optimistic) -def _create_prop(class_, key, uselist, callable_, typecallable, useobject, **kwargs): +def _create_prop(class_, key, uselist, callable_, typecallable, useobject, mutable_scalars, **kwargs): if kwargs.pop('dynamic', False): from sqlalchemy.orm import dynamic return dynamic.DynamicAttributeImpl(class_, key, typecallable, **kwargs) elif uselist: return CollectionAttributeImpl(class_, key, callable_, typecallable, **kwargs) elif useobject: - return ScalarObjectAttributeImpl(class_, key, callable_, - **kwargs) + return ScalarObjectAttributeImpl(class_, key, callable_,**kwargs) + elif mutable_scalars: + return MutableScalarAttributeImpl(class_, key, callable_, **kwargs) else: - return ScalarAttributeImpl(class_, key, callable_, - **kwargs) + return ScalarAttributeImpl(class_, key, callable_, **kwargs) def manage(instance): """initialize an InstanceState on the given instance.""" @@ -1081,7 +1158,7 @@ def unregister_class(class_): delattr(class_, attr.impl.key) delattr(class_, '_class_state') -def register_attribute(class_, key, uselist, useobject, callable_=None, proxy_property=None, **kwargs): +def register_attribute(class_, key, uselist, useobject, callable_=None, proxy_property=None, mutable_scalars=False, **kwargs): _init_class_state(class_) typecallable = kwargs.pop('typecallable', None) @@ -1099,7 +1176,7 @@ def register_attribute(class_, key, uselist, useobject, callable_=None, proxy_pr inst = ProxiedAttribute(key, proxy_property, comparator=comparator) else: inst = InstrumentedAttribute(_create_prop(class_, key, uselist, callable_, useobject=useobject, - typecallable=typecallable, **kwargs), comparator=comparator) + typecallable=typecallable, mutable_scalars=mutable_scalars, **kwargs), comparator=comparator) setattr(class_, key, inst) class_._class_state.attrs[key] = inst diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 7334e46642..ddbf6f0051 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -575,7 +575,7 @@ class CollectionAdapter(object): self.attr.fire_append_event(self.owner_state, item, initiator) def fire_remove_event(self, item, initiator=None): - """Notify that a entity has entered the collection. + """Notify that a entity has been removed from the collection. Initiator is the InstrumentedAttribute that initiated the membership mutation, and should be left as None unless you are passing along @@ -585,6 +585,15 @@ class CollectionAdapter(object): if initiator is not False and item is not None: self.attr.fire_remove_event(self.owner_state, item, initiator) + def fire_pre_remove_event(self, initiator=None): + """Notify that an entity is about to be removed from the collection. + + Only called if the entity cannot be removed after calling + fire_remove_event(). + """ + + self.attr.fire_pre_remove_event(self.owner_state, initiator=initiator) + def __getstate__(self): return { 'key': self.attr.key, 'owner_state': self.owner_state, @@ -838,6 +847,13 @@ def __del(collection, item, _sa_initiator=None): if executor: getattr(executor, 'fire_remove_event')(item, _sa_initiator) +def __before_delete(collection, _sa_initiator=None): + """Special method to run 'commit existing value' methods""" + + executor = getattr(collection, '_sa_adapter', None) + if executor: + getattr(executor, 'fire_pre_remove_event')(_sa_initiator) + def _list_decorators(): """Hand-turned instrumentation wrappers that can decorate any list-like class.""" @@ -862,6 +878,7 @@ def _list_decorators(): def remove(fn): def remove(self, value, _sa_initiator=None): # testlib.pragma exempt:__eq__ + __before_delete(self, _sa_initiator) fn(self, value) __del(self, value, _sa_initiator) _tidy(remove) @@ -953,6 +970,7 @@ def _list_decorators(): def pop(fn): def pop(self, index=-1): + __before_delete(self) item = fn(self, index) __del(self, item) return item @@ -1011,6 +1029,7 @@ def _dict_decorators(): def popitem(fn): def popitem(self): + __before_delete(self) item = fn(self) __del(self, item[1]) return item @@ -1098,6 +1117,7 @@ def _set_decorators(): def pop(fn): def pop(self): + __before_delete(self) item = fn(self) __del(self, item) return item diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index 63bdaea40a..ea99d65148 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -17,13 +17,27 @@ class DynamicAttributeImpl(attributes.AttributeImpl): else: return AppenderQuery(self, state) - def commit_to_state(self, state, value=attributes.NO_VALUE): - # we have our own AttributeHistory therefore dont need CommittedState - # instead, we reset the history stored on the attribute - state.dict[self.key] = CollectionHistory(self, state) - def get_collection(self, state, user_data=None): return self._get_collection(state, passive=True).added_items + + def fire_append_event(self, state, value, initiator): + state.modified = True + + if self.trackparent and value is not None: + self.sethasparent(value._state, True) + instance = state.obj() + for ext in self.extensions: + ext.append(instance, value, initiator or self) + + def fire_remove_event(self, state, value, initiator): + state.modified = True + + if self.trackparent and value is not None: + self.sethasparent(value._state, False) + + instance = state.obj() + for ext in self.extensions: + ext.remove(instance, value, initiator or self) def set(self, state, value, initiator): if initiator is self: diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 413a1af2cd..6119d1c6e6 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -383,7 +383,7 @@ class MapperProperty(object): level (as opposed to the individual instance level). """ - return self.parent._is_primary_mapper() + return not self.parent.non_primary def merge(self, session, source, dest): """Merge the attribute represented by this ``MapperProperty`` diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index e9fe41fdc7..6ea45a598a 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -158,11 +158,11 @@ class Mapper(object): def __log(self, msg): if self.__should_log_info: - self.logger.info("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (not self._is_primary_mapper() and "|non-primary" or "") + ") " + msg) + self.logger.info("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (not self.non_primary and "|non-primary" or "") + ") " + msg) def __log_debug(self, msg): if self.__should_log_debug: - self.logger.debug("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (not self._is_primary_mapper() and "|non-primary" or "") + ") " + msg) + self.logger.debug("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (not self.non_primary and "|non-primary" or "") + ") " + msg) def _is_orphan(self, obj): optimistic = has_identity(obj) @@ -299,8 +299,8 @@ class Mapper(object): self.inherits = self.inherits if not issubclass(self.class_, self.inherits.class_): raise exceptions.ArgumentError("Class '%s' does not inherit from '%s'" % (self.class_.__name__, self.inherits.class_.__name__)) - if self._is_primary_mapper() != self.inherits._is_primary_mapper(): - np = self._is_primary_mapper() and "primary" or "non-primary" + if self.non_primary != self.inherits.non_primary: + np = not self.non_primary and "primary" or "non-primary" raise exceptions.ArgumentError("Inheritance of %s mapper for class '%s' is only allowed from a %s mapper" % (np, self.class_.__name__, np)) # inherit_condition is optional. if self.local_table is None: @@ -815,36 +815,12 @@ class Mapper(object): self._compile_property(key, prop, init=self.__props_init) def __str__(self): - return "Mapper|" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (not self._is_primary_mapper() and "|non-primary" or "") - - def _is_primary_mapper(self): - """Return True if this mapper is the primary mapper for its class key (class + entity_name).""" - # FIXME: cant we just look at "non_primary" flag ? - return self._class_state.mappers[self.entity_name] is self + return "Mapper|" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (self.non_primary and "|non-primary" or "") def primary_mapper(self): """Return the primary mapper corresponding to this mapper's class key (class + entity_name).""" return self._class_state.mappers[self.entity_name] - def is_assigned(self, instance): - """Return True if this mapper handles the given instance. - - This is dependent not only on class assignment but the - optional `entity_name` parameter as well. - """ - - return instance.__class__ is self.class_ and getattr(instance, '_entity_name', None) == self.entity_name - - def _assign_entity_name(self, instance): - """Assign this Mapper's entity name to the given instance. - - Subsequent Mapper lookups for this instance will return the - primary mapper corresponding to this Mapper's class and entity - name. - """ - - instance._entity_name = self.entity_name - def get_session(self): """Return the contextual session provided by the mapper extension chain, if any. @@ -858,7 +834,7 @@ class Mapper(object): if s is not EXT_CONTINUE: return s - raise exceptions.InvalidRequestError("No contextual Session is established. Use a MapperExtension that implements get_session or use 'import sqlalchemy.mods.threadlocal' to establish a default thread-local contextual session.") + raise exceptions.InvalidRequestError("No contextual Session is established.") def instances(self, cursor, session, *mappers, **kwargs): """Return a list of mapped instances corresponding to the rows @@ -967,16 +943,19 @@ class Mapper(object): self.save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True) return + # if session has a connection callable, + # organize individual states with the connection to use for insert/update if 'connection_callable' in uowtransaction.mapper_flush_opts: connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] - tups = [(state, connection_callable(self, state.obj())) for state in states] + tups = [(state, connection_callable(self, state.obj()), _state_has_identity(state)) for state in states] else: connection = uowtransaction.transaction.connection(self) - tups = [(state, connection) for state in states] + tups = [(state, connection, _state_has_identity(state)) for state in states] if not postupdate: - for state, connection in tups: - if not _state_has_identity(state): + # call before_XXX extensions + for state, connection, has_identity in tups: + if not has_identity: for mapper in _state_mapper(state).iterate_to_root(): if 'before_insert' in mapper.extension.methods: mapper.extension.before_insert(mapper, connection, state.obj()) @@ -985,13 +964,13 @@ class Mapper(object): if 'before_update' in mapper.extension.methods: mapper.extension.before_update(mapper, connection, state.obj()) - for state, connection in tups: + for state, connection, has_identity in tups: # detect if we have a "pending" instance (i.e. has no instance_key attached to it), # and another instance with the same identity key already exists as persistent. convert to an # UPDATE if so. mapper = _state_mapper(state) instance_key = mapper._identity_key_from_state(state) - if not postupdate and not _state_has_identity(state) and instance_key in uowtransaction.uow.identity_map: + if not postupdate and not has_identity and instance_key in uowtransaction.uow.identity_map: existing = uowtransaction.uow.identity_map[instance_key] if not uowtransaction.is_deleted(existing): raise exceptions.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (mapperutil.state_str(state), str(instance_key), mapperutil.instance_str(existing))) @@ -1008,11 +987,10 @@ class Mapper(object): table_to_mapper[t] = mapper for table in sqlutil.sort_tables(table_to_mapper.keys()): - # two lists to store parameters for each table/object pair located insert = [] update = [] - for state, connection in tups: + for state, connection, has_identity in tups: mapper = _state_mapper(state) if table not in mapper._pks_by_table: continue @@ -1022,7 +1000,7 @@ class Mapper(object): if self.__should_log_debug: self.__log_debug("save_obj() table '%s' instance %s identity %s" % (table.name, mapperutil.state_str(state), str(instance_key))) - isinsert = not instance_key in uowtransaction.uow.identity_map and not postupdate and not _state_has_identity(state) + isinsert = not instance_key in uowtransaction.uow.identity_map and not postupdate and not has_identity params = {} value_params = {} hasdata = False @@ -1088,13 +1066,14 @@ class Mapper(object): if update: mapper = table_to_mapper[table] clause = sql.and_() + for col in mapper._pks_by_table[table]: clause.clauses.append(col == sql.bindparam(col._label, type_=col.type)) + if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col): clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type)) + statement = table.update(clause) - rows = 0 - supports_sane_rowcount = True pks = mapper._pks_by_table[table] def comparator(a, b): for col in pks: @@ -1103,6 +1082,8 @@ class Mapper(object): return x return 0 update.sort(comparator) + + rows = 0 for rec in update: (state, params, mapper, connection, value_params) = rec c = connection.execute(statement.values(value_params), params) @@ -1126,11 +1107,10 @@ class Mapper(object): primary_key = c.last_inserted_ids() if primary_key is not None: - i = 0 - for col in mapper._pks_by_table[table]: + # set primary key attributes + for i, col in enumerate(mapper._pks_by_table[table]): if mapper._get_state_attr_by_column(state, col) is None and len(primary_key) > i: mapper._set_state_attr_by_column(state, col, primary_key[i]) - i+=1 mapper._postfetch(connection, table, state, c, c.last_inserted_params(), value_params) # synchronize newly inserted ids from one table to the next @@ -1144,6 +1124,7 @@ class Mapper(object): inserted_objects.add((state, connection)) if not postupdate: + # call after_XXX extensions for state, connection in inserted_objects: for mapper in _state_mapper(state).iterate_to_root(): if 'after_insert' in mapper.extension.methods: @@ -1167,10 +1148,7 @@ class Mapper(object): if c in postfetch_cols and (not c.key in params or c in value_params): prop = self._columntoproperty[c] deferred_props.append(prop.key) - continue - if c.primary_key or not c.key in params: - continue - if self._get_state_attr_by_column(state, c) != params[c.key]: + elif not c.primary_key and c.key in params and self._get_state_attr_by_column(state, c) != params[c.key]: self._set_state_attr_by_column(state, c, params[c.key]) if deferred_props: @@ -1295,14 +1273,6 @@ class Mapper(object): return self.__surrogate_mapper or self def _instance(self, context, row, result=None, skip_polymorphic=False, extension=None, only_load_props=None, refresh_instance=None): - """Pull an object instance from the given row and append it to - the given result list. - - If the instance already exists in the given identity map, its - not added. In either case, execute all the property loaders - on the instance to also process extra information in the row. - """ - if not extension: extension = self.extension @@ -1311,71 +1281,52 @@ class Mapper(object): if ret is not EXT_CONTINUE: row = ret - if refresh_instance is None: - if not skip_polymorphic and self.polymorphic_on is not None: - discriminator = row[self.polymorphic_on] - if discriminator is not None: - mapper = self.polymorphic_map[discriminator] - if mapper is not self: - if ('polymorphic_fetch', mapper) not in context.attributes: - context.attributes[('polymorphic_fetch', mapper)] = (self, [t for t in mapper.tables if t not in self.tables]) - row = self.translate_row(mapper, row) - return mapper._instance(context, row, result=result, skip_polymorphic=True) + if not refresh_instance and not skip_polymorphic and self.polymorphic_on is not None: + discriminator = row[self.polymorphic_on] + if discriminator is not None: + mapper = self.polymorphic_map[discriminator] + if mapper is not self: + if ('polymorphic_fetch', mapper) not in context.attributes: + context.attributes[('polymorphic_fetch', mapper)] = (self, [t for t in mapper.tables if t not in self.tables]) + row = self.translate_row(mapper, row) + return mapper._instance(context, row, result=result, skip_polymorphic=True) - # determine identity key if refresh_instance: identitykey = refresh_instance.dict['_instance_key'] else: identitykey = self.identity_key_from_row(row) - (session_identity_map, local_identity_map) = (context.session.identity_map, context.identity_map) + + session_identity_map = context.session.identity_map - # look in main identity map. if present, we only populate - # if repopulate flags are set. this block returns the instance. if identitykey in session_identity_map: instance = session_identity_map[identitykey] + state = instance._state if self.__should_log_debug: self.__log_debug("_instance(): using existing instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey))) - - isnew = False - if context.version_check and self.version_id_col is not None and self._get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]: - raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self._get_attr_by_column(instance, self.version_id_col), row[self.version_id_col])) - - if context.populate_existing or self.always_refresh or instance._state.trigger is not None: - instance._state.trigger = None - if identitykey not in local_identity_map: - local_identity_map[identitykey] = instance - isnew = True - if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, instancekey=identitykey, isnew=isnew, only_load_props=only_load_props) is EXT_CONTINUE: - self.populate_instance(context, instance, row, instancekey=identitykey, isnew=isnew, only_load_props=only_load_props) - - if 'append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: - if result is not None: - result.append(instance) + isnew = state.runid != context.runid + currentload = not isnew - return instance + if not currentload and context.version_check and self.version_id_col and self._get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]: + raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self._get_attr_by_column(instance, self.version_id_col), row[self.version_id_col])) - elif self.__should_log_debug: - self.__log_debug("_instance(): identity key %s not in session" % str(identitykey)) + else: + if self.__should_log_debug: + self.__log_debug("_instance(): identity key %s not in session" % str(identitykey)) - # look in identity map which is local to this load operation - if identitykey not in local_identity_map: - # check that sufficient primary key columns are present if self.allow_null_pks: - # check if *all* primary key cols in the result are None - this indicates - # an instance of the object is not present in the row. for x in identitykey[1]: if x is not None: break else: return None else: - # otherwise, check if *any* primary key cols in the result are None - this indicates - # an instance of the object is not present in the row. if None in identitykey[1]: return None + isnew = True + currentload = True if 'create_instance' in extension.methods: instance = extension.create_instance(self, context, row, self.class_) @@ -1386,30 +1337,29 @@ class Mapper(object): else: instance = attributes.new_instance(self.class_) - instance._entity_name = self.entity_name - instance._instance_key = identitykey - if self.__should_log_debug: self.__log_debug("_instance(): created new instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey))) - - local_identity_map[identitykey] = instance - isnew = True - else: - # instance is already present - instance = local_identity_map[identitykey] - isnew = False - - # populate. note that we still call this for an instance already loaded as additional collection state is present - # in subsequent rows (i.e. eagerly loaded collections) - flags = {'instancekey':identitykey, 'isnew':isnew} - if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, **flags) is EXT_CONTINUE: - self.populate_instance(context, instance, row, only_load_props=only_load_props, **flags) - if 'append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, **flags) is EXT_CONTINUE: - if result is not None: - result.append(instance) + + state = instance._state + instance._entity_name = self.entity_name + instance._instance_key = identitykey + instance._sa_session_id = context.session.hash_key + session_identity_map[identitykey] = instance + if currentload or context.populate_existing or self.always_refresh or state.trigger: + if isnew: + state.runid = context.runid + state.trigger = None + context.progress.add(state) + + if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: + self.populate_instance(context, instance, row, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) + + if result is not None and ('append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE): + result.append(instance) + return instance - + def _deferred_inheritance_condition(self, base_mapper, needs_tables): def visit_binary(binary): leftcol = binary.left @@ -1494,23 +1444,24 @@ class Mapper(object): selectcontext.exec_with_path(self, key, populator, instance, row, ispostselect=ispostselect, isnew=isnew, **flags) if self.non_primary: - selectcontext.attributes[('populating_mapper', id(instance))] = self + selectcontext.attributes[('populating_mapper', instance._state)] = self - def _post_instance(self, selectcontext, instance): + def _post_instance(self, selectcontext, state): post_processors = selectcontext.attributes[('post_processors', self, None)] for p in post_processors: - p(instance) + p(state.obj()) def _get_poly_select_loader(self, selectcontext, row): # 'select' or 'union'+col not present (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', self), (None, None)) - if hosted_mapper is None or len(needs_tables)==0 or hosted_mapper.polymorphic_fetch == 'deferred': + if hosted_mapper is None or not needs_tables or hosted_mapper.polymorphic_fetch == 'deferred': return cond, param_names = self._deferred_inheritance_condition(hosted_mapper, needs_tables) statement = sql.select(needs_tables, cond, use_labels=True) def post_execute(instance, **flags): - self.__log_debug("Post query loading instance " + mapperutil.instance_str(instance)) + if self.__should_log_debug: + self.__log_debug("Post query loading instance " + mapperutil.instance_str(instance)) identitykey = self.identity_key_from_instance(instance) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index d4e6ccb408..d75a9b8b27 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -8,12 +8,14 @@ from sqlalchemy import sql, util, exceptions, logging from sqlalchemy.sql import util as sql_util from sqlalchemy.sql import expression, visitors, operators from sqlalchemy.orm import mapper, object_mapper +from sqlalchemy.orm.mapper import _state_mapper from sqlalchemy.orm import util as mapperutil from itertools import chain import warnings __all__ = ['Query', 'QueryContext'] + class Query(object): """Encapsulates the object-fetching operations provided by Mappers.""" @@ -504,6 +506,11 @@ class Query(object): """apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``""" q = self._no_statement("order_by") + + if self._aliases is not None: + criterion = [expression._literal_as_text(o) for o in util.to_list(criterion) or []] + criterion = self._aliases.adapt_list(criterion) + if q._order_by is False: q._order_by = util.to_list(criterion) else: @@ -737,23 +744,31 @@ class Query(object): result.close() def instances(self, cursor, *mappers_or_columns, **kwargs): - """Return a list of mapped instances corresponding to the rows - in a given *cursor* (i.e. ``ResultProxy``). - - The \*mappers_or_columns and \**kwargs arguments are deprecated. - To add instances or columns to the results, use add_entity() - and add_column(). - """ - session = self.session context = kwargs.pop('querycontext', None) if context is None: context = QueryContext(self) - - process = [] + + context.runid = _new_runid() + mappers_or_columns = tuple(self._entities) + mappers_or_columns - if mappers_or_columns: + tuples = bool(mappers_or_columns) + + if self._primary_adapter: + def main(context, row): + return self.select_mapper._instance(context, self._primary_adapter(row), None, + extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance + ) + else: + def main(context, row): + return self.select_mapper._instance(context, row, None, + extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance + ) + + if tuples: + process = [] + process.append(main) for tup in mappers_or_columns: if isinstance(tup, tuple): (m, alias, alias_id) = tup @@ -761,63 +776,46 @@ class Query(object): else: clauses = alias = alias_id = None m = tup + if isinstance(m, type): m = mapper.class_mapper(m) + if isinstance(m, mapper.Mapper): def x(m): row_adapter = clauses is not None and clauses.row_decorator or (lambda row: row) - appender = [] def proc(context, row): - if not m._instance(context, row_adapter(row), appender): - appender.append(None) - process.append((proc, appender)) + return m._instance(context, row_adapter(row), None) + process.append(proc) x(m) elif isinstance(m, (sql.ColumnElement, basestring)): def y(m): row_adapter = clauses is not None and clauses.row_decorator or (lambda row: row) - res = [] def proc(context, row): - res.append(row_adapter(row)[m]) - process.append((proc, res)) + return row_adapter(row)[m] + process.append(proc) y(m) else: raise exceptions.InvalidRequestError("Invalid column expression '%r'" % m) - - result = [] + + context.progress = util.Set() + if tuples: + rows = util.OrderedSet() + for row in cursor.fetchall(): + rows.add(tuple(proc(context, row) for proc in process)) else: - result = util.UniqueAppender([]) - - primary_mapper_args = dict(extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance) - - for row in cursor.fetchall(): - if self._primary_adapter: - self.select_mapper._instance(context, self._primary_adapter(row), result, **primary_mapper_args) - else: - self.select_mapper._instance(context, row, result, **primary_mapper_args) - for proc in process: - proc[0](context, row) - - for instance in context.identity_map.values(): - context.attributes.get(('populating_mapper', id(instance)), object_mapper(instance))._post_instance(context, instance) - - # "refresh_instance" may be loaded from a row which has no primary key columns to identify it. - # this occurs during the load of the "joined" table in a joined-table inheritance mapper, and saves - # the need to join to the primary table in those cases. - if context.refresh_instance and context.only_load_props and context.refresh_instance.dict['_instance_key'] in context.identity_map: - # if refreshing partial instance, do special state commit - # affecting only the refreshed attributes + rows = util.UniqueAppender([]) + for row in cursor.fetchall(): + rows.append(main(context, row)) + + if context.refresh_instance and context.only_load_props and context.refresh_instance in context.progress: context.refresh_instance.commit(context.only_load_props) - del context.identity_map[context.refresh_instance.dict['_instance_key']] - - # store new stuff in the identity map - for instance in context.identity_map.values(): - session._register_persistent(instance) - - if mappers_or_columns: - return list(util.OrderedSet(zip(*([result] + [o[1] for o in process])))) - else: - return result.data + context.progress.remove(context.refresh_instance) + + for ii in context.progress: + context.attributes.get(('populating_mapper', ii), _state_mapper(ii))._post_instance(context, ii) + ii.commit_all() + return list(rows) def _get(self, key=None, ident=None, refresh_instance=None, lockmode=None, only_load_props=None): lockmode = lockmode or self._lockmode @@ -1323,7 +1321,6 @@ class QueryContext(object): self.version_check = query._version_check self.only_load_props = query._only_load_props self.refresh_instance = query._refresh_instance - self.identity_map = {} self.path = () self.primary_columns = [] self.secondary_columns = [] @@ -1340,3 +1337,15 @@ class QueryContext(object): finally: self.path = oldpath +_runid = 1 +_id_lock = util.threading.Lock() + +def _new_runid(): + global _runid + _id_lock.acquire() + try: + _runid += 1 + return _runid + finally: + _id_lock.release() + diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 3111b699d5..ce4e20bc93 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1015,11 +1015,6 @@ class Session(object): self._attach(instance) self.uow.register_deleted(instance) - def _register_persistent(self, instance): - instance._sa_session_id = self.hash_key - self.identity_map[instance._instance_key] = instance - instance._state.commit_all() - def _attach(self, instance): old_id = getattr(instance, '_sa_session_id', None) if old_id != self.hash_key: diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 2ba9d6be1e..60fc025790 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -603,8 +603,7 @@ class EagerLoader(AbstractRelationLoader): # so that we further descend into properties self.select_mapper._instance(selectcontext, decorated_row, None) else: - appender_key = ('appender', id(instance), self.key) - if isnew or appender_key not in selectcontext.attributes: + if isnew or self.key not in instance._state.appenders: # appender_key can be absent from selectcontext.attributes with isnew=False # when self-referential eager loading is used; the same instance may be present # in two distinct sets of result columns @@ -615,10 +614,9 @@ class EagerLoader(AbstractRelationLoader): collection = attributes.init_collection(instance, self.key) appender = util.UniqueAppender(collection, 'append_without_event') - # store it in the "scratch" area, which is local to this load operation. - selectcontext.attributes[appender_key] = appender + instance._state.appenders[self.key] = appender - result_list = selectcontext.attributes[appender_key] + result_list = instance._state.appenders[self.key] if self._should_log_debug: self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key)) diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index d2782ec0ab..6e31b46468 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -174,6 +174,9 @@ class AliasedClauses(object): def adapt_clause(self, clause): return sql_util.ClauseAdapter(self.alias).traverse(clause, clone=True) + def adapt_list(self, clauses): + return sql_util.ClauseAdapter(self.alias).copy_and_process(clauses) + def _create_row_adapter(self): """Return a callable which, when passed a RowProxy, will return a new dict-like object diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 705168d209..3e26217c9b 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -488,26 +488,30 @@ class OrderedSet(Set): __or__ = union def intersection(self, other): - return self.__class__([a for a in self if a in other]) + other = Set(other) + return self.__class__([a for a in self if a in other]) __and__ = intersection def symmetric_difference(self, other): - result = self.__class__([a for a in self if a not in other]) - result.update([a for a in other if a not in self]) - return result + other = Set(other) + result = self.__class__([a for a in self if a not in other]) + result.update([a for a in other if a not in self]) + return result __xor__ = symmetric_difference def difference(self, other): - return self.__class__([a for a in self if a not in other]) + other = Set(other) + return self.__class__([a for a in self if a not in other]) __sub__ = difference def intersection_update(self, other): - Set.intersection_update(self, other) - self._list = [ a for a in self._list if a in other] - return self + other = Set(other) + Set.intersection_update(self, other) + self._list = [ a for a in self._list if a in other] + return self __iand__ = intersection_update @@ -520,9 +524,9 @@ class OrderedSet(Set): __ixor__ = symmetric_difference_update def difference_update(self, other): - Set.difference_update(self, other) - self._list = [ a for a in self._list if a in self] - return self + Set.difference_update(self, other) + self._list = [ a for a in self._list if a in self] + return self __isub__ = difference_update @@ -536,6 +540,7 @@ class IdentitySet(object): def __init__(self, iterable=None): self._members = _IterableUpdatableDict() + self._tempset = Set if iterable: for o in iterable: self.add(o) @@ -625,7 +630,7 @@ class IdentitySet(object): result = type(self)() # testlib.pragma exempt:__hash__ result._members.update( - Set(self._members.iteritems()).union(_iter_id(iterable))) + self._tempset(self._members.iteritems()).union(_iter_id(iterable))) return result def __or__(self, other): @@ -647,7 +652,7 @@ class IdentitySet(object): result = type(self)() # testlib.pragma exempt:__hash__ result._members.update( - Set(self._members.iteritems()).difference(_iter_id(iterable))) + self._tempset(self._members.iteritems()).difference(_iter_id(iterable))) return result def __sub__(self, other): @@ -669,7 +674,7 @@ class IdentitySet(object): result = type(self)() # testlib.pragma exempt:__hash__ result._members.update( - Set(self._members.iteritems()).intersection(_iter_id(iterable))) + self._tempset(self._members.iteritems()).intersection(_iter_id(iterable))) return result def __and__(self, other): @@ -691,7 +696,7 @@ class IdentitySet(object): result = type(self)() # testlib.pragma exempt:__hash__ result._members.update( - Set(self._members.iteritems()).symmetric_difference(_iter_id(iterable))) + self._tempset(self._members.iteritems()).symmetric_difference(_iter_id(iterable))) return result def __xor__(self, other): @@ -749,6 +754,7 @@ class OrderedIdentitySet(IdentitySet): def __init__(self, iterable=None): IdentitySet.__init__(self) self._members = OrderedDict() + self._tempset = OrderedSet if iterable: for o in iterable: self.add(o) diff --git a/test/base/utils.py b/test/base/utils.py index 932ad876a2..6e1b58c4a8 100644 --- a/test/base/utils.py +++ b/test/base/utils.py @@ -35,6 +35,18 @@ class OrderedDictTest(PersistTest): self.assert_(o.keys() == ['a', 'b', 'c', 'd', 'e', 'f']) self.assert_(o.values() == [1, 2, 3, 4, 5, 6]) +class OrderedSetTest(PersistTest): + def test_mutators_against_iter(self): + # testing a set modified against an iterator + o = util.OrderedSet([3,2, 4, 5]) + + self.assertEquals(o.difference(iter([3,4])), util.OrderedSet([2,5])) + + self.assertEquals(o.intersection(iter([3,4, 6])), util.OrderedSet([3, 4])) + + self.assertEquals(o.union(iter([3,4, 6])), util.OrderedSet([2, 3, 4, 5, 6])) + + class ColumnCollectionTest(PersistTest): def test_in(self): cc = sql.ColumnCollection() diff --git a/test/orm/attributes.py b/test/orm/attributes.py index 4d56a01ec1..3b3ed61025 100644 --- a/test/orm/attributes.py +++ b/test/orm/attributes.py @@ -4,6 +4,7 @@ import sqlalchemy.orm.attributes as attributes from sqlalchemy.orm.collections import collection from sqlalchemy import exceptions from testlib import * +from testlib import fixtures ROLLBACK_SUPPORTED=False @@ -13,7 +14,7 @@ class MyTest(object):pass class MyTest2(object):pass class AttributesTest(PersistTest): - """tests for the attributes.py module, which deals with tracking attribute changes on an object.""" + def test_basic(self): class User(object):pass @@ -23,31 +24,19 @@ class AttributesTest(PersistTest): attributes.register_attribute(User, 'email_address', uselist = False, useobject=False) u = User() - print repr(u.__dict__) - u.user_id = 7 u.user_name = 'john' u.email_address = 'lala@123.com' - print repr(u.__dict__) self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') u._state.commit_all() - print repr(u.__dict__) self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') u.user_name = 'heythere' u.email_address = 'foo@bar.com' - print repr(u.__dict__) self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.email_address == 'foo@bar.com') - if ROLLBACK_SUPPORTED: - attributes.rollback(u) - print repr(u.__dict__) - self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') - def test_pickleness(self): - - attributes.register_class(MyTest) attributes.register_class(MyTest2) attributes.register_attribute(MyTest, 'user_id', uselist = False, useobject=False) @@ -68,21 +57,21 @@ class AttributesTest(PersistTest): pk_o = pickle.dumps(o) o2 = pickle.loads(pk_o) + pk_o2 = pickle.dumps(o2) # so... pickle is creating a new 'mt2' string after a roundtrip here, # so we'll brute-force set it to be id-equal to the original string - o_mt2_str = [ k for k in o.__dict__ if k == 'mt2'][0] - o2_mt2_str = [ k for k in o2.__dict__ if k == 'mt2'][0] - self.assert_(o_mt2_str == o2_mt2_str) - self.assert_(o_mt2_str is not o2_mt2_str) - # change the id of o2.__dict__['mt2'] - former = o2.__dict__['mt2'] - del o2.__dict__['mt2'] - o2.__dict__[o_mt2_str] = former - - pk_o2 = pickle.dumps(o2) - - self.assert_(pk_o == pk_o2) + if False: + o_mt2_str = [ k for k in o.__dict__ if k == 'mt2'][0] + o2_mt2_str = [ k for k in o2.__dict__ if k == 'mt2'][0] + self.assert_(o_mt2_str == o2_mt2_str) + self.assert_(o_mt2_str is not o2_mt2_str) + # change the id of o2.__dict__['mt2'] + former = o2.__dict__['mt2'] + del o2.__dict__['mt2'] + o2.__dict__[o_mt2_str] = former + + self.assert_(pk_o == pk_o2) # the above is kind of distrurbing, so let's do it again a little # differently. the string-id in serialization thing is just an @@ -119,8 +108,6 @@ class AttributesTest(PersistTest): attributes.register_attribute(Address, 'email_address', uselist = False, useobject=False) u = User() - print repr(u.__dict__) - u.user_id = 7 u.user_name = 'john' u.addresses = [] @@ -129,10 +116,8 @@ class AttributesTest(PersistTest): a.email_address = 'lala@123.com' u.addresses.append(a) - print repr(u.__dict__) self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') u, a._state.commit_all() - print repr(u.__dict__) self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') u.user_name = 'heythere' @@ -140,16 +125,7 @@ class AttributesTest(PersistTest): a.address_id = 11 a.email_address = 'foo@bar.com' u.addresses.append(a) - print repr(u.__dict__) self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.addresses[0].email_address == 'lala@123.com' and u.addresses[1].email_address == 'foo@bar.com') - - if ROLLBACK_SUPPORTED: - attributes.rollback(u, a) - print repr(u.__dict__) - print repr(u.addresses[0].__dict__) - self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') - self.assert_(len(attributes.get_history(u, 'addresses').unchanged_items()) == 1) - def test_lazytrackparent(self): """test that the "hasparent" flag works properly when lazy loaders and backrefs are used""" @@ -225,9 +201,9 @@ class AttributesTest(PersistTest): attributes.register_attribute(Foo, 'element', uselist=False, useobject=True) x = Bar() x.element = 'this is the element' - (added, unchanged, deleted) = attributes.get_history(x._state, 'element') - assert added == ['this is the element'] + self.assertEquals(attributes.get_history(x._state, 'element'), (['this is the element'],[], [])) x._state.commit_all() + (added, unchanged, deleted) = attributes.get_history(x._state, 'element') assert added == [] assert unchanged == ['this is the element'] @@ -235,21 +211,19 @@ class AttributesTest(PersistTest): def test_lazyhistory(self): """tests that history functions work with lazy-loading attributes""" - class Foo(object):pass - class Bar(object): - def __init__(self, id): - self.id = id - def __repr__(self): - return "Bar: id %d" % self.id - + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + pass attributes.register_class(Foo) attributes.register_class(Bar) + bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), Bar(id=4)] def func1(): return "this is func 1" def func2(): - return [Bar(1), Bar(2), Bar(3)] + return [bar1, bar2, bar3] attributes.register_attribute(Foo, 'col1', uselist=False, callable_=lambda o:func1, useobject=True) attributes.register_attribute(Foo, 'col2', uselist=True, callable_=lambda o:func2, useobject=True) @@ -257,9 +231,11 @@ class AttributesTest(PersistTest): x = Foo() x._state.commit_all() - x.col2.append(Bar(4)) + x.col2.append(bar4) (added, unchanged, deleted) = attributes.get_history(x._state, 'col2') - + + self.assertEquals(set(unchanged), set([bar1, bar2, bar3])) + self.assertEquals(added, [bar4]) def test_parenttrack(self): class Foo(object):pass @@ -541,6 +517,284 @@ class DeferredBackrefTest(PersistTest): called[0] = 0 lazy_load = (p1, p2, p3) = [Post("post 1"), Post("post 2"), Post("post 3")] + +class HistoryTest(PersistTest): + def test_scalar(self): + class Foo(fixtures.Base): + pass + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=False) + + # case 1. new object + f = Foo() + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], [])) + + f.someattr = "hi" + self.assertEquals(attributes.get_history(f._state, 'someattr'), (['hi'], [], [])) + + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['hi'], [])) + f.someattr = 'there' + + self.assertEquals(attributes.get_history(f._state, 'someattr'), (['there'], [], ['hi'])) + f._state.commit(['someattr']) + + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['there'], [])) + + # case 2. object with direct dictionary settings (similar to a load operation) + f = Foo() + f.__dict__['someattr'] = 'new' + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['new'], [])) + + f.someattr = 'old' + self.assertEquals(attributes.get_history(f._state, 'someattr'), (['old'], [], ['new'])) + + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['old'], [])) + + # setting None on uninitialized is currently a change for a scalar attribute + # no lazyload occurs so this allows overwrite operation to proceed + f = Foo() + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], [])) + f.someattr = None + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([None], [], [])) + + f = Foo() + f.__dict__['someattr'] = 'new' + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['new'], [])) + f.someattr = None + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([None], [], ['new'])) + + def test_mutable_scalar(self): + class Foo(fixtures.Base): + pass + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=False, mutable_scalars=True, copy_function=dict) + + # case 1. new object + f = Foo() + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], [])) + + f.someattr = {'foo':'hi'} + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([{'foo':'hi'}], [], [])) + + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'hi'}], [])) + self.assertEquals(f._state.committed_state['someattr'], {'foo':'hi'}) + + f.someattr['foo'] = 'there' + self.assertEquals(f._state.committed_state['someattr'], {'foo':'hi'}) + + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([{'foo':'there'}], [], [{'foo':'hi'}])) + f._state.commit(['someattr']) + + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'there'}], [])) + + # case 2. object with direct dictionary settings (similar to a load operation) + f = Foo() + f.__dict__['someattr'] = {'foo':'new'} + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'new'}], [])) + + f.someattr = {'foo':'old'} + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([{'foo':'old'}], [], [{'foo':'new'}])) + + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'old'}], [])) + + + def test_use_object(self): + class Foo(fixtures.Base): + pass + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=True) + + # case 1. new object + f = Foo() + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [None], [])) + + f.someattr = "hi" + self.assertEquals(attributes.get_history(f._state, 'someattr'), (['hi'], [], [])) + + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['hi'], [])) + + f.someattr = 'there' + + self.assertEquals(attributes.get_history(f._state, 'someattr'), (['there'], [], ['hi'])) + f._state.commit(['someattr']) + + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['there'], [])) + + # case 2. object with direct dictionary settings (similar to a load operation) + f = Foo() + f.__dict__['someattr'] = 'new' + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['new'], [])) + + f.someattr = 'old' + self.assertEquals(attributes.get_history(f._state, 'someattr'), (['old'], [], ['new'])) + + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['old'], [])) + + # setting None on uninitialized is currently not a change for an object attribute + # (this is different than scalar attribute). a lazyload has occured so if its + # None, its really None + f = Foo() + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [None], [])) + f.someattr = None + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [None], [])) + + f = Foo() + f.__dict__['someattr'] = 'new' + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['new'], [])) + f.someattr = None + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([None], [], ['new'])) + + def test_object_collections_set(self): + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + def __nonzero__(self): + assert False + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'someattr', uselist=True, useobject=True) + + hi = Bar(name='hi') + there = Bar(name='there') + old = Bar(name='old') + new = Bar(name='new') + + # case 1. new object + f = Foo() + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], [])) + + f.someattr = [hi] + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], [])) + + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], [])) + + f.someattr = [there] + + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [], [hi])) + f._state.commit(['someattr']) + + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [there], [])) + + # case 2. object with direct settings (similar to a load operation) + f = Foo() + collection = attributes.init_collection(f, 'someattr') + collection.append_without_event(new) + f._state.commit_all() + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], [])) + + f.someattr = [old] + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old], [], [new])) + + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [old], [])) + + def test_object_collections_mutate(self): + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + pass + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'someattr', uselist=True, useobject=True) + attributes.register_attribute(Foo, 'id', uselist=False, useobject=False) + + hi = Bar(name='hi') + there = Bar(name='there') + old = Bar(name='old') + new = Bar(name='new') + + # case 1. new object + f = Foo(id=1) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], [])) + + f.someattr.append(hi) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], [])) + + f._state.commit(['someattr']) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], [])) + + f.someattr.append(there) + + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [hi], [])) + f._state.commit(['someattr']) + + self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set(), set([there, hi]), set())) + + f.someattr.remove(there) + self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set(), set([hi]), set([there]))) + + f.someattr.append(old) + f.someattr.append(new) + self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set([new, old]), set([hi]), set([there]))) + f._state.commit(['someattr']) + self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set(), set([new, old, hi]), set())) + + f.someattr.pop(0) + self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set(), set([new, old]), set([hi]))) + + # case 2. object with direct settings (similar to a load operation) + f = Foo() + f.__dict__['id'] = 1 + collection = attributes.init_collection(f, 'someattr') + collection.append_without_event(new) + f._state.commit_all() + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], [])) + + f.someattr.append(old) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old], [new], [])) + + f._state.commit(['someattr']) + self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set([]), set([old, new]), set([]))) + + f = Foo() + collection = attributes.init_collection(f, 'someattr') + collection.append_without_event(new) + f._state.commit_all() + print f._state.dict + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], [])) + + f.id = 1 + f.someattr.remove(new) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], [new])) + + def test_collections_via_backref(self): + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + pass + + attributes.register_class(Foo) + attributes.register_class(Bar) + attributes.register_attribute(Foo, 'bars', uselist=True, extension=attributes.GenericBackrefExtension('foo'), trackparent=True, useobject=True) + attributes.register_attribute(Bar, 'foo', uselist=False, extension=attributes.GenericBackrefExtension('bars'), trackparent=True, useobject=True) + + f1 = Foo() + b1 = Bar() + self.assertEquals(attributes.get_history(f1._state, 'bars'), ([], [], [])) + self.assertEquals(attributes.get_history(b1._state, 'foo'), ([], [None], [])) + + #b1.foo = f1 + f1.bars.append(b1) + self.assertEquals(attributes.get_history(f1._state, 'bars'), ([b1], [], [])) + self.assertEquals(attributes.get_history(b1._state, 'foo'), ([f1], [], [])) + + b2 = Bar() + f1.bars.append(b2) + self.assertEquals(tuple([set(x) for x in attributes.get_history(f1._state, 'bars')]), (set([b1, b2]), set([]), set([]))) + self.assertEquals(attributes.get_history(b1._state, 'foo'), ([f1], [], [])) + self.assertEquals(attributes.get_history(b2._state, 'foo'), ([f1], [], [])) + + if __name__ == "__main__": testbase.main() diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py index 192eafaed3..7a822234ce 100644 --- a/test/orm/eager_relations.py +++ b/test/orm/eager_relations.py @@ -545,24 +545,44 @@ class AddEntityTest(FixtureTest): def _assert_result(self): return [ ( - User(id=7, addresses=[Address(id=1)]), - Order(id=1, items=[Item(id=1), Item(id=2), Item(id=3)]), + User(id=7, + addresses=[Address(id=1)] + ), + Order(id=1, + items=[Item(id=1), Item(id=2), Item(id=3)] + ), ), ( - User(id=7, addresses=[Address(id=1)]), - Order(id=3, items=[Item(id=3), Item(id=4), Item(id=5)]), + User(id=7, + addresses=[Address(id=1)] + ), + Order(id=3, + items=[Item(id=3), Item(id=4), Item(id=5)] + ), ), ( - User(id=7, addresses=[Address(id=1)]), - Order(id=5, items=[Item(id=5)]), + User(id=7, + addresses=[Address(id=1)] + ), + Order(id=5, + items=[Item(id=5)] + ), ), ( - User(id=9, addresses=[Address(id=5)]), - Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)]), + User(id=9, + addresses=[Address(id=5)] + ), + Order(id=2, + items=[Item(id=1), Item(id=2), Item(id=3)] + ), ), ( - User(id=9, addresses=[Address(id=5)]), - Order(id=4, items=[Item(id=1), Item(id=5)]), + User(id=9, + addresses=[Address(id=5)] + ), + Order(id=4, + items=[Item(id=1), Item(id=5)] + ), ) ] @@ -573,14 +593,14 @@ class AddEntityTest(FixtureTest): }) mapper(Address, addresses) mapper(Order, orders, properties={ - 'items':relation(Item, secondary=order_items, lazy=False) + 'items':relation(Item, secondary=order_items, lazy=False, order_by=items.c.id) }) mapper(Item, items) sess = create_session() def go(): - ret = sess.query(User).add_entity(Order).join('orders', aliased=True).all() + ret = sess.query(User).add_entity(Order).join('orders', aliased=True).order_by(User.id).order_by(Order.id).all() self.assertEquals(ret, self._assert_result()) self.assert_sql_count(testbase.db, go, 1) @@ -591,20 +611,20 @@ class AddEntityTest(FixtureTest): }) mapper(Address, addresses) mapper(Order, orders, properties={ - 'items':relation(Item, secondary=order_items) + 'items':relation(Item, secondary=order_items, order_by=items.c.id) }) mapper(Item, items) sess = create_session() def go(): - ret = sess.query(User).options(eagerload('addresses')).add_entity(Order).join('orders', aliased=True).all() + ret = sess.query(User).options(eagerload('addresses')).add_entity(Order).join('orders', aliased=True).order_by(User.id).order_by(Order.id).all() self.assertEquals(ret, self._assert_result()) self.assert_sql_count(testbase.db, go, 6) sess.clear() def go(): - ret = sess.query(User).options(eagerload('addresses')).add_entity(Order).options(eagerload('items', Order)).join('orders', aliased=True).all() + ret = sess.query(User).options(eagerload('addresses')).add_entity(Order).options(eagerload('items', Order)).join('orders', aliased=True).order_by(User.id).order_by(Order.id).all() self.assertEquals(ret, self._assert_result()) self.assert_sql_count(testbase.db, go, 1) diff --git a/test/orm/expire.py b/test/orm/expire.py index 7f0e9002bb..be9c881c0c 100644 --- a/test/orm/expire.py +++ b/test/orm/expire.py @@ -87,7 +87,7 @@ class ExpireTest(FixtureTest): orders.update(id=3).execute(description='order 3 modified') assert o.isopen == 1 - assert o._state.committed_state['description'] == 'order 3 modified' + assert o._state.dict['description'] == 'order 3 modified' def go(): sess.flush() self.assert_sql_count(testbase.db, go, 0) @@ -158,14 +158,14 @@ class ExpireTest(FixtureTest): sess.expire(o, attribute_names=['description']) assert 'id' in o.__dict__ assert 'description' not in o.__dict__ - assert o._state.committed_state['isopen'] == 1 + assert o._state.dict['isopen'] == 1 orders.update(orders.c.id==3).execute(description='order 3 modified') def go(): assert o.description == 'order 3 modified' self.assert_sql_count(testbase.db, go, 1) - assert o._state.committed_state['description'] == 'order 3 modified' + assert o._state.dict['description'] == 'order 3 modified' o.isopen = 5 sess.expire(o, attribute_names=['description']) @@ -178,7 +178,7 @@ class ExpireTest(FixtureTest): assert o.description == 'order 3 modified' self.assert_sql_count(testbase.db, go, 1) assert o.__dict__['isopen'] == 5 - assert o._state.committed_state['description'] == 'order 3 modified' + assert o._state.dict['description'] == 'order 3 modified' assert o._state.committed_state['isopen'] == 1 sess.flush() diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 3847a49a02..d8b552c69c 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -1207,8 +1207,8 @@ class MapperExtensionTest(MapperSuperTest): sess.flush() sess.delete(u) sess.flush() - assert methods == set(['load', 'append_result', 'before_delete', 'create_instance', 'translate_row', 'get', - 'after_delete', 'after_insert', 'before_update', 'before_insert', 'after_update', 'populate_instance']) + self.assertEquals(methods, set(['load', 'before_delete', 'create_instance', 'translate_row', 'get', + 'after_delete', 'after_insert', 'before_update', 'before_insert', 'after_update', 'populate_instance'])) class RequirementsTest(AssertMixin): diff --git a/test/orm/relationships.py b/test/orm/relationships.py index fe11553612..8bfc626521 100644 --- a/test/orm/relationships.py +++ b/test/orm/relationships.py @@ -344,6 +344,7 @@ class RelationTest3(PersistTest): s.save(j1) s.save(j2) + s.flush() s.clear() @@ -366,10 +367,10 @@ class RelationTest4(ORMTest): tableA = Table("A", metadata, Column("id",Integer,primary_key=True), Column("foo",Integer,), - ) + test_needs_fk=True) tableB = Table("B",metadata, Column("id",Integer,ForeignKey("A.id"),primary_key=True), - ) + test_needs_fk=True) def test_no_delete_PK_AtoB(self): """test that A cant be deleted without B because B would have no PK value""" class A(object):pass @@ -411,26 +412,6 @@ class RelationTest4(ORMTest): except exceptions.AssertionError, e: assert str(e).startswith("Dependency rule tried to blank-out primary key column 'B.id' on instance ") - def test_no_nullPK_BtoA(self): - class A(object):pass - class B(object):pass - mapper(B, tableB, properties={ - 'a':relation(A, cascade="save-update") - }) - mapper(A, tableA) - b1 = B() - b1.a = None - sess = create_session() - sess.save(b1) - try: - # this raises an error as of r3695. in that rev, the attributes package was modified so that a - # setting of "None" shows up as a change, which in turn fires off dependency.py and then triggers - # the rule. - sess.flush() - assert False - except exceptions.AssertionError, e: - assert str(e).startswith("Dependency rule tried to blank-out primary key column 'B.id' on instance ") - @testing.fails_on_everything_except('sqlite', 'mysql') def test_nullPKsOK_BtoA(self): # postgres cant handle a nullable PK column...? diff --git a/test/orm/session.py b/test/orm/session.py index 777f1afee5..db9245c728 100644 --- a/test/orm/session.py +++ b/test/orm/session.py @@ -8,6 +8,8 @@ from testlib import * from testlib.tables import * from testlib import fixtures, tables import pickle +import gc + class SessionTest(AssertMixin): def setUpAll(self): @@ -534,32 +536,36 @@ class SessionTest(AssertMixin): """test the weak-referencing identity map, which strongly-references modified items.""" s = create_session() - class User(object):pass + class User(fixtures.Base):pass mapper(User, users) - # save user - s.save(User()) + s.save(User(user_name='ed')) s.flush() + assert not s.dirty + user = s.query(User).one() - user = None - import gc + del user gc.collect() assert len(s.identity_map) == 0 assert len(s.identity_map.data) == 0 user = s.query(User).one() user.user_name = 'fred' - user = None + del user gc.collect() assert len(s.identity_map) == 1 assert len(s.identity_map.data) == 1 + assert len(s.dirty) == 1 s.flush() gc.collect() - assert len(s.identity_map) == 0 - assert len(s.identity_map.data) == 0 + assert not s.dirty + assert not s.identity_map + assert not s.identity_map.data - assert s.query(User).one().user_name == 'fred' + user = s.query(User).one() + assert user.user_name == 'fred' + assert s.identity_map def test_strong_ref(self): s = create_session(weak_identity_map=False) diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index 87a565f3e9..47ce70fa7a 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -762,10 +762,10 @@ class DefaultTest(ORMTest): mapper(Hoho, default_table) h1 = Hoho() Session.commit() - self.assert_(h1.foober == 'im foober') + self.assertEquals(h1.foober, 'im foober') h1.counter = 19 Session.commit() - self.assert_(h1.foober == 'im the update') + self.assertEquals(h1.foober, 'im the update') class OneToManyTest(ORMTest): metadata = tables.metadata