From: Mike Bayer Date: Mon, 25 May 2009 15:20:44 +0000 (+0000) Subject: merge -r5936:5974 of trunk X-Git-Tag: rel_0_6_6~212 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=10640dba13748a50d2cad34962094553b33f7d19;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git merge -r5936:5974 of trunk --- diff --git a/CHANGES b/CHANGES index e950ca0a1f..165d2ec079 100644 --- a/CHANGES +++ b/CHANGES @@ -4,14 +4,77 @@ CHANGES ======= +0.5.4p1 +======= + +- orm + - Fixed an attribute error introduced in 0.5.4 which would + occur when merge() was used with an incomplete object. + 0.5.4 ===== - orm + - Significant performance enhancements regarding Sessions/flush() + in conjunction with large mapper graphs, large numbers of + objects: + + - Removed all* O(N) scanning behavior from the flush() process, + i.e. operations that were scanning the full session, + including an extremely expensive one that was erroneously + assuming primary key values were changing when this + was not the case. + + * one edge case remains which may invoke a full scan, + if an existing primary key attribute is modified + to a new value. + + - The Session's "weak referencing" behavior is now *full* - + no strong references whatsoever are made to a mapped object + or related items/collections in its __dict__. Backrefs and + other cycles in objects no longer affect the Session's ability + to lose all references to unmodified objects. Objects with + pending changes still are maintained strongly until flush. + [ticket:1398] + + The implementation also improves performance by moving + the "resurrection" process of garbage collected items + to only be relevant for mappings that map "mutable" + attributes (i.e. PickleType, composite attrs). This removes + overhead from the gc process and simplifies internal + behavior. + + If a "mutable" attribute change is the sole change on an object + which is then dereferenced, the mapper will not have access to + other attribute state when the UPDATE is issued. This may present + itself differently to some MapperExtensions. + + The change also affects the internal attribute API, but not + the AttributeExtension interface nor any of the publically + documented attribute functions. + + - The unit of work no longer genererates a graph of "dependency" + processors for the full graph of mappers during flush(), instead + creating such processors only for those mappers which represent + objects with pending changes. This saves a tremendous number + of method calls in the context of a large interconnected + graph of mappers. + + - Cached a wasteful "table sort" operation that previously + occured multiple times per flush, also removing significant + method call count from flush(). + + - Other redundant behaviors have been simplified in + mapper._save_obj(). + - Modified query_cls on DynamicAttributeImpl to accept a full mixin version of the AppenderQuery, which allows subclassing the AppenderMixin. + - The "polymorphic discriminator" column may be part of a + primary key, and it will be populated with the correct + discriminator value. [ticket:1300] + - Fixed the evaluator not being able to evaluate IS NULL clauses. - Fixed the "set collection" function on "dynamic" relations to @@ -44,12 +107,20 @@ CHANGES - Fixed another location where autoflush was interfering with session.merge(). autoflush is disabled completely for the duration of merge() now. [ticket:1360] - + + - Fixed bug which prevented "mutable primary key" dependency + logic from functioning properly on a one-to-one + relation(). [ticket:1406] + - Fixed bug in relation(), introduced in 0.5.3, whereby a self referential relation from a base class to a joined-table subclass would not configure correctly. + - Fixed obscure mapper compilation issue when inheriting + mappers are used which would result in un-initialized + attributes. + - Fixed documentation for session weak_identity_map - the default value is True, indicating a weak referencing map in use. @@ -62,6 +133,11 @@ CHANGES - Fixed Query.update() and Query.delete() failures with eagerloaded relations. [ticket:1378] + - It is now an error to specify both columns of a binary primaryjoin + condition in the foreign_keys or remote_side collection. Whereas + previously it was just nonsensical, but would succeed in a + non-deterministic way. + - schema - Added a quote_schema() method to the IdentifierPreparer class so that dialects can override how schemas get handled. This @@ -69,6 +145,18 @@ CHANGES identifiers, such as 'database.owner'. [ticket: 594, 1341] - sql + - Back-ported the "compiler" extension from SQLA 0.6. This + is a standardized interface which allows the creation of custom + ClauseElement subclasses and compilers. In particular it's + handy as an alternative to text() when you'd like to + build a construct that has database-specific compilations. + See the extension docs for details. + + - Exception messages are truncated when the list of bound + parameters is larger than 10, preventing enormous + multi-page exceptions from filling up screens and logfiles + for large executemany() statements. [ticket:1413] + - ``sqlalchemy.extract()`` is now dialect sensitive and can extract components of timestamps idiomatically across the supported databases, including SQLite. @@ -77,6 +165,11 @@ CHANGES ForeignKey constructed from __clause_element__() style construct (i.e. declarative columns). [ticket:1353] +- mysql + - Reflecting a FOREIGN KEY construct will take into account + a dotted schema.tablename combination, if the foreign key + references a table in a remote schema. [ticket:1405] + - mssql - Modified how savepoint logic works to prevent it from stepping on non-savepoint oriented routines. Savepoint @@ -90,6 +183,9 @@ CHANGES since it is only used by mssql now. [ticket:1343] - sqlite + - Corrected the SLBoolean type so that it properly treats only 1 + as True. [ticket:1402] + - Corrected the float type so that it correctly maps to a SLFloat type when being reflected. [ticket:1273] diff --git a/doc/build/sqlexpression.rst b/doc/build/sqlexpression.rst index 70aaf6dced..4d54d036bd 100644 --- a/doc/build/sqlexpression.rst +++ b/doc/build/sqlexpression.rst @@ -260,7 +260,7 @@ Integer indexes work as well: >>> row = result.fetchone() >>> print "name:", row[1], "; fullname:", row[2] - name: jack ; fullname: Jack Jones + name: wendy ; fullname: Wendy Williams But another way, whose usefulness will become apparent later on, is to use the ``Column`` objects directly as keys: diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index 5d15f5d763..1b1b968557 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -109,6 +109,6 @@ from sqlalchemy.engine import create_engine, engine_from_config __all__ = sorted(name for name, obj in locals().items() if not (name.startswith('_') or inspect.ismodule(obj))) -__version__ = '0.5.3' +__version__ = '0.6beta1' del inspect, sys diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index fd5ba7348e..3bb6536a3c 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -2507,7 +2507,7 @@ class MySQLTableDefinitionParser(object): r'%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' r'FOREIGN KEY +' r'\((?P[^\)]+?)\) REFERENCES +' - r'(?P%(iq)s[^%(fq)s]+%(fq)s) +' + r'(?P
%(iq)s[^%(fq)s]+%(fq)s(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +' r'\((?P[^\)]+?)\)' r'(?: +(?PMATCH \w+))?' r'(?: +ON DELETE (?P%(on)s))?' diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 0c7400c2bd..c6228ca2f3 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -134,7 +134,7 @@ class SLBoolean(sqltypes.Boolean): def process(value): if value is None: return None - return value and True or False + return value == 1 return process colspecs = { diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index 799abbf0d3..d9fdd5df92 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -132,6 +132,11 @@ class DBAPIError(SQLAlchemyError): self.connection_invalidated = connection_invalidated def __str__(self): + if len(self.params) > 10: + return ' '.join((SQLAlchemyError.__str__(self), + repr(self.statement), + repr(self.params[:2]), + '... and a total of %i bound parameters' % len(self.params))) return ' '.join((SQLAlchemyError.__str__(self), repr(self.statement), repr(self.params))) diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index a6861ee452..1df37b4e1e 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -20,14 +20,13 @@ import types import weakref from sqlalchemy import util -from sqlalchemy.util import EMPTY_SET from sqlalchemy.orm import interfaces, collections, exc import sqlalchemy.exceptions as sa_exc # lazy imports _entity_info = None identity_equal = None - +state = None PASSIVE_NORESULT = util.symbol('PASSIVE_NORESULT') ATTR_WAS_SET = util.symbol('ATTR_WAS_SET') @@ -105,7 +104,7 @@ class QueryableAttribute(interfaces.PropComparator): self.parententity = parententity def get_history(self, instance, **kwargs): - return self.impl.get_history(instance_state(instance), **kwargs) + return self.impl.get_history(instance_state(instance), instance_dict(instance), **kwargs) def __selectable__(self): # TODO: conditionally attach this method based on clause_element ? @@ -148,15 +147,15 @@ class InstrumentedAttribute(QueryableAttribute): """Public-facing descriptor, placed in the mapped class dictionary.""" def __set__(self, instance, value): - self.impl.set(instance_state(instance), value, None) + self.impl.set(instance_state(instance), instance_dict(instance), value, None) def __delete__(self, instance): - self.impl.delete(instance_state(instance)) + self.impl.delete(instance_state(instance), instance_dict(instance)) def __get__(self, instance, owner): if instance is None: return self - return self.impl.get(instance_state(instance)) + return self.impl.get(instance_state(instance), instance_dict(instance)) class _ProxyImpl(object): accepts_scalar_loader = False @@ -335,7 +334,7 @@ class AttributeImpl(object): else: state.callables[self.key] = callable_ - def get_history(self, state, passive=PASSIVE_OFF): + def get_history(self, state, dict_, passive=PASSIVE_OFF): raise NotImplementedError() def _get_callable(self, state): @@ -346,13 +345,13 @@ class AttributeImpl(object): else: return None - def initialize(self, state): + def initialize(self, state, dict_): """Initialize this attribute on the given object instance with an empty value.""" - state.dict[self.key] = None + dict_[self.key] = None return None - def get(self, state, passive=PASSIVE_OFF): + def get(self, state, dict_, passive=PASSIVE_OFF): """Retrieve a value from the given object. If a callable is assembled on this object's attribute, and @@ -361,7 +360,7 @@ class AttributeImpl(object): """ try: - return state.dict[self.key] + return dict_[self.key] except KeyError: # if no history, check for lazy callables, etc. if state.committed_state.get(self.key, NEVER_SET) is NEVER_SET: @@ -374,25 +373,25 @@ class AttributeImpl(object): return PASSIVE_NORESULT value = callable_() if value is not ATTR_WAS_SET: - return self.set_committed_value(state, value) + return self.set_committed_value(state, dict_, value) else: - if self.key not in state.dict: - return self.get(state, passive=passive) - return state.dict[self.key] + if self.key not in dict_: + return self.get(state, dict_, passive=passive) + return dict_[self.key] # Return a new, empty value - return self.initialize(state) + return self.initialize(state, dict_) - def append(self, state, value, initiator, passive=PASSIVE_OFF): - self.set(state, value, initiator) + def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + self.set(state, dict_, value, initiator) - def remove(self, state, value, initiator, passive=PASSIVE_OFF): - self.set(state, None, initiator) + def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + self.set(state, dict_, None, initiator) - def set(self, state, value, initiator): + def set(self, state, dict_, value, initiator): raise NotImplementedError() - def get_committed_value(self, state, passive=PASSIVE_OFF): + def get_committed_value(self, state, dict_, passive=PASSIVE_OFF): """return the unchanged value of this attribute""" if self.key in state.committed_state: @@ -401,12 +400,12 @@ class AttributeImpl(object): else: return state.committed_state.get(self.key) else: - return self.get(state, passive=passive) + return self.get(state, dict_, passive=passive) - def set_committed_value(self, state, value): + def set_committed_value(self, state, dict_, value): """set an attribute value on the given instance and 'commit' it.""" - state.commit([self.key]) + state.commit(dict_, [self.key]) state.callables.pop(self.key, None) state.dict[self.key] = value @@ -419,45 +418,45 @@ class ScalarAttributeImpl(AttributeImpl): accepts_scalar_loader = True uses_objects = False - def delete(self, state): + def delete(self, state, dict_): # TODO: catch key errors, convert to attributeerror? if self.active_history or self.extensions: - old = self.get(state) + old = self.get(state, dict_) else: - old = state.dict.get(self.key, NO_VALUE) + old = dict_.get(self.key, NO_VALUE) - state.modified_event(self, False, old) + state.modified_event(dict_, self, False, old) if self.extensions: - self.fire_remove_event(state, old, None) - del state.dict[self.key] + self.fire_remove_event(state, dict_, old, None) + del dict_[self.key] - def get_history(self, state, passive=PASSIVE_OFF): + def get_history(self, state, dict_, passive=PASSIVE_OFF): return History.from_attribute( - self, state, state.dict.get(self.key, NO_VALUE)) + self, state, dict_.get(self.key, NO_VALUE)) - def set(self, state, value, initiator): + def set(self, state, dict_, value, initiator): if initiator is self: return if self.active_history or self.extensions: - old = self.get(state) + old = self.get(state, dict_) else: - old = state.dict.get(self.key, NO_VALUE) + old = dict_.get(self.key, NO_VALUE) - state.modified_event(self, False, old) + state.modified_event(dict_, self, False, old) if self.extensions: - value = self.fire_replace_event(state, value, old, initiator) - state.dict[self.key] = value + value = self.fire_replace_event(state, dict_, value, old, initiator) + dict_[self.key] = value - def fire_replace_event(self, state, value, previous, initiator): + def fire_replace_event(self, state, dict_, value, previous, initiator): for ext in self.extensions: value = ext.set(state, value, previous, initiator or self) return value - def fire_remove_event(self, state, value, initiator): + def fire_remove_event(self, state, dict_, value, initiator): for ext in self.extensions: ext.remove(state, value, initiator or self) @@ -483,29 +482,48 @@ class MutableScalarAttributeImpl(ScalarAttributeImpl): raise sa_exc.ArgumentError("MutableScalarAttributeImpl requires a copy function") self.copy = copy_function - def get_history(self, state, passive=PASSIVE_OFF): + def get_history(self, state, dict_, passive=PASSIVE_OFF): + if not dict_: + v = state.committed_state.get(self.key, NO_VALUE) + else: + v = dict_.get(self.key, NO_VALUE) + return History.from_attribute( - self, state, state.dict.get(self.key, NO_VALUE)) + self, state, v) - def commit_to_state(self, state, dest): - dest[self.key] = self.copy(state.dict[self.key]) + def commit_to_state(self, state, dict_, dest): + dest[self.key] = self.copy(dict_[self.key]) - def check_mutable_modified(self, state): - (added, unchanged, deleted) = self.get_history(state, passive=PASSIVE_NO_INITIALIZE) + def check_mutable_modified(self, state, dict_): + (added, unchanged, deleted) = self.get_history(state, dict_, passive=PASSIVE_NO_INITIALIZE) return bool(added or deleted) - def set(self, state, value, initiator): + def get(self, state, dict_, passive=PASSIVE_OFF): + if self.key not in state.mutable_dict: + ret = ScalarAttributeImpl.get(self, state, dict_, passive=passive) + if ret is not PASSIVE_NORESULT: + state.mutable_dict[self.key] = ret + return ret + else: + return state.mutable_dict[self.key] + + def delete(self, state, dict_): + ScalarAttributeImpl.delete(self, state, dict_) + state.mutable_dict.pop(self.key) + + def set(self, state, dict_, value, initiator): if initiator is self: return - state.modified_event(self, True, NEVER_SET) - + state.modified_event(dict_, self, True, NEVER_SET) + if self.extensions: - old = self.get(state) - value = self.fire_replace_event(state, value, old, initiator) - state.dict[self.key] = value + old = self.get(state, dict_) + value = self.fire_replace_event(state, dict_, value, old, initiator) + dict_[self.key] = value else: - state.dict[self.key] = value + dict_[self.key] = value + state.mutable_dict[self.key] = value class ScalarObjectAttributeImpl(ScalarAttributeImpl): @@ -526,22 +544,22 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): if compare_function is None: self.is_equal = identity_equal - def delete(self, state): - old = self.get(state) - self.fire_remove_event(state, old, self) - del state.dict[self.key] + def delete(self, state, dict_): + old = self.get(state, dict_) + self.fire_remove_event(state, dict_, old, self) + del dict_[self.key] - def get_history(self, state, passive=PASSIVE_OFF): - if self.key in state.dict: - return History.from_attribute(self, state, state.dict[self.key]) + def get_history(self, state, dict_, passive=PASSIVE_OFF): + if self.key in dict_: + return History.from_attribute(self, state, dict_[self.key]) else: - current = self.get(state, passive=passive) + current = self.get(state, dict_, passive=passive) if current is PASSIVE_NORESULT: return HISTORY_BLANK else: return History.from_attribute(self, state, current) - def set(self, state, value, initiator): + def set(self, state, dict_, value, initiator): """Set a value on the given InstanceState. `initiator` is the ``InstrumentedAttribute`` that initiated the @@ -553,12 +571,12 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): return # may want to add options to allow the get() here to be passive - old = self.get(state) - value = self.fire_replace_event(state, value, old, initiator) - state.dict[self.key] = value + old = self.get(state, dict_) + value = self.fire_replace_event(state, dict_, value, old, initiator) + dict_[self.key] = value - def fire_remove_event(self, state, value, initiator): - state.modified_event(self, False, value) + def fire_remove_event(self, state, dict_, value, initiator): + state.modified_event(dict_, self, False, value) if self.trackparent and value is not None: self.sethasparent(instance_state(value), False) @@ -566,8 +584,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): for ext in self.extensions: ext.remove(state, value, initiator or self) - def fire_replace_event(self, state, value, previous, initiator): - state.modified_event(self, False, previous) + def fire_replace_event(self, state, dict_, value, previous, initiator): + state.modified_event(dict_, self, False, previous) if self.trackparent: if previous is not value and previous is not None: @@ -615,15 +633,15 @@ class CollectionAttributeImpl(AttributeImpl): def __copy(self, item): return [y for y in list(collections.collection_adapter(item))] - def get_history(self, state, passive=PASSIVE_OFF): - current = self.get(state, passive=passive) + def get_history(self, state, dict_, passive=PASSIVE_OFF): + current = self.get(state, dict_, passive=passive) if current is PASSIVE_NORESULT: return HISTORY_BLANK else: return History.from_attribute(self, state, current) - def fire_append_event(self, state, value, initiator): - state.modified_event(self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) + def fire_append_event(self, state, dict_, value, initiator): + state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) for ext in self.extensions: value = ext.append(state, value, initiator or self) @@ -633,11 +651,11 @@ class CollectionAttributeImpl(AttributeImpl): return value - def fire_pre_remove_event(self, state, initiator): - state.modified_event(self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) + def fire_pre_remove_event(self, state, dict_, initiator): + state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) - def fire_remove_event(self, state, value, initiator): - state.modified_event(self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) + def fire_remove_event(self, state, dict_, value, initiator): + state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) if self.trackparent and value is not None: self.sethasparent(instance_state(value), False) @@ -645,51 +663,51 @@ class CollectionAttributeImpl(AttributeImpl): for ext in self.extensions: ext.remove(state, value, initiator or self) - def delete(self, state): - if self.key not in state.dict: + def delete(self, state, dict_): + if self.key not in dict_: return - state.modified_event(self, True, NEVER_SET) + state.modified_event(dict_, self, True, NEVER_SET) - collection = self.get_collection(state) + collection = self.get_collection(state, state.dict) collection.clear_with_event() # TODO: catch key errors, convert to attributeerror? - del state.dict[self.key] + del dict_[self.key] - def initialize(self, state): + def initialize(self, state, dict_): """Initialize this attribute with an empty collection.""" _, user_data = self._initialize_collection(state) - state.dict[self.key] = user_data + dict_[self.key] = user_data return user_data def _initialize_collection(self, state): return state.manager.initialize_collection( self.key, state, self.collection_factory) - def append(self, state, value, initiator, passive=PASSIVE_OFF): + def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF): if initiator is self: return - collection = self.get_collection(state, passive=passive) + collection = self.get_collection(state, dict_, passive=passive) if collection is PASSIVE_NORESULT: - value = self.fire_append_event(state, value, initiator) + value = self.fire_append_event(state, dict_, value, initiator) state.get_pending(self.key).append(value) else: collection.append_with_event(value, initiator) - def remove(self, state, value, initiator, passive=PASSIVE_OFF): + def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF): if initiator is self: return - collection = self.get_collection(state, passive=passive) + collection = self.get_collection(state, state.dict, passive=passive) if collection is PASSIVE_NORESULT: - self.fire_remove_event(state, value, initiator) + self.fire_remove_event(state, dict_, value, initiator) state.get_pending(self.key).remove(value) else: collection.remove_with_event(value, initiator) - def set(self, state, value, initiator): + def set(self, state, dict_, value, initiator): """Set a value on the given object. `initiator` is the ``InstrumentedAttribute`` that initiated the @@ -701,10 +719,10 @@ class CollectionAttributeImpl(AttributeImpl): return self._set_iterable( - state, value, + state, dict_, value, lambda adapter, i: adapter.adapt_like_to_iterable(i)) - def _set_iterable(self, state, iterable, adapter=None): + def _set_iterable(self, state, dict_, iterable, adapter=None): """Set a collection value from an iterable of state-bearers. ``adapter`` is an optional callable invoked with a CollectionAdapter @@ -722,24 +740,24 @@ class CollectionAttributeImpl(AttributeImpl): else: new_values = list(iterable) - old = self.get(state) + old = self.get(state, dict_) # ignore re-assignment of the current collection, as happens # implicitly with in-place operators (foo.collection |= other) if old is iterable: return - state.modified_event(self, True, old) + state.modified_event(dict_, self, True, old) - old_collection = self.get_collection(state, old) + old_collection = self.get_collection(state, dict_, old) - state.dict[self.key] = user_data + dict_[self.key] = user_data collections.bulk_replace(new_values, old_collection, new_collection) old_collection.unlink(old) - def set_committed_value(self, state, value): + def set_committed_value(self, state, dict_, value): """Set an attribute value on the given instance and 'commit' it.""" collection, user_data = self._initialize_collection(state) @@ -751,13 +769,13 @@ class CollectionAttributeImpl(AttributeImpl): state.callables.pop(self.key, None) state.dict[self.key] = user_data - state.commit([self.key]) + state.commit(dict_, [self.key]) if self.key in state.pending: # pending items exist. issue a modified event, # add/remove new items. - state.modified_event(self, True, user_data) + state.modified_event(dict_, self, True, user_data) pending = state.pending.pop(self.key) added = pending.added_items @@ -769,14 +787,14 @@ class CollectionAttributeImpl(AttributeImpl): return user_data - def get_collection(self, state, user_data=None, passive=PASSIVE_OFF): + def get_collection(self, state, dict_, user_data=None, passive=PASSIVE_OFF): """Retrieve the CollectionAdapter associated with the given state. Creates a new CollectionAdapter if one does not exist. """ if user_data is None: - user_data = self.get(state, passive=passive) + user_data = self.get(state, dict_, passive=passive) if user_data is PASSIVE_NORESULT: return user_data @@ -799,327 +817,26 @@ class GenericBackrefExtension(interfaces.AttributeExtension): if oldchild is not None: # With lazy=None, there's no guarantee that the full collection is # present when updating via a backref. - old_state = instance_state(oldchild) + old_state, old_dict = instance_state(oldchild), instance_dict(oldchild) impl = old_state.get_impl(self.key) try: - impl.remove(old_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) + impl.remove(old_state, old_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) except (ValueError, KeyError, IndexError): pass if child is not None: - new_state = instance_state(child) - new_state.get_impl(self.key).append(new_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) + new_state, new_dict = instance_state(child), instance_dict(child) + new_state.get_impl(self.key).append(new_state, new_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) return child def append(self, state, child, initiator): - child_state = instance_state(child) - child_state.get_impl(self.key).append(child_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) + child_state, child_dict = instance_state(child), instance_dict(child) + child_state.get_impl(self.key).append(child_state, child_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) return child def remove(self, state, child, initiator): if child is not None: - child_state = instance_state(child) - child_state.get_impl(self.key).remove(child_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) - - -class InstanceState(object): - """tracks state information at the instance level.""" - - session_id = None - key = None - runid = None - expired_attributes = EMPTY_SET - load_options = EMPTY_SET - load_path = () - insert_order = None - - def __init__(self, obj, manager): - self.class_ = obj.__class__ - self.manager = manager - - self.obj = weakref.ref(obj, self._cleanup) - self.dict = obj.__dict__ - self.modified = False - self.callables = {} - self.expired = False - self.committed_state = {} - self.pending = {} - self.parents = {} - - def detach(self): - if self.session_id: - try: - del self.session_id - except AttributeError: - pass - - def dispose(self): - if self.session_id: - try: - del self.session_id - except AttributeError: - pass - del self.obj - del self.dict - - def _cleanup(self, ref): - self.dispose() - - def obj(self): - return None - - @util.memoized_property - def dict(self): - # return a blank dict - # if none is available, so that asynchronous gc - # doesn't blow up expiration operations in progress - # (usually expire_attributes) - return {} - - @property - def sort_key(self): - return self.key and self.key[1] or (self.insert_order, ) - - def check_modified(self): - if self.modified: - return True - else: - for key in self.manager.mutable_attributes: - if self.manager[key].impl.check_mutable_modified(self): - return True - else: - return False - - def initialize_instance(*mixed, **kwargs): - self, instance, args = mixed[0], mixed[1], mixed[2:] - manager = self.manager - - for fn in manager.events.on_init: - fn(self, instance, args, kwargs) - try: - return manager.events.original_init(*mixed[1:], **kwargs) - except: - for fn in manager.events.on_init_failure: - fn(self, instance, args, kwargs) - raise - - def get_history(self, key, **kwargs): - return self.manager.get_impl(key).get_history(self, **kwargs) - - def get_impl(self, key): - return self.manager.get_impl(key) - - def get_pending(self, key): - if key not in self.pending: - self.pending[key] = PendingCollection() - return self.pending[key] - - def value_as_iterable(self, key, passive=PASSIVE_OFF): - """return an InstanceState attribute as a list, - regardless of it being a scalar or collection-based - attribute. - - returns None if passive is not PASSIVE_OFF and the getter returns - PASSIVE_NORESULT. - """ - - impl = self.get_impl(key) - x = impl.get(self, passive=passive) - if x is PASSIVE_NORESULT: - - return None - elif hasattr(impl, 'get_collection'): - return impl.get_collection(self, x, passive=passive) - elif isinstance(x, list): - return x - else: - return [x] - - def _run_on_load(self, instance=None): - if instance is None: - instance = self.obj() - self.manager.events.run('on_load', instance) - - def __getstate__(self): - return {'key': self.key, - 'committed_state': self.committed_state, - 'pending': self.pending, - 'parents': self.parents, - 'modified': self.modified, - 'expired':self.expired, - 'load_options':self.load_options, - 'load_path':interfaces.serialize_path(self.load_path), - 'instance': self.obj(), - 'expired_attributes':self.expired_attributes, - 'callables': self.callables} - - def __setstate__(self, state): - self.committed_state = state['committed_state'] - self.parents = state['parents'] - self.key = state['key'] - self.session_id = None - self.pending = state['pending'] - self.modified = state['modified'] - self.obj = weakref.ref(state['instance']) - self.load_options = state['load_options'] or EMPTY_SET - self.load_path = interfaces.deserialize_path(state['load_path']) - self.class_ = self.obj().__class__ - self.manager = manager_of_class(self.class_) - self.dict = self.obj().__dict__ - self.callables = state['callables'] - self.runid = None - self.expired = state['expired'] - self.expired_attributes = state['expired_attributes'] - - def initialize(self, key): - self.manager.get_impl(key).initialize(self) - - def set_callable(self, key, callable_): - self.dict.pop(key, None) - self.callables[key] = callable_ - - def __call__(self): - """__call__ allows the InstanceState to act as a deferred - callable for loading expired attributes, which is also - serializable (picklable). - - """ - unmodified = self.unmodified - class_manager = self.manager - class_manager.deferred_scalar_loader(self, [ - attr.impl.key for attr in class_manager.attributes if - attr.impl.accepts_scalar_loader and - attr.impl.key in self.expired_attributes and - attr.impl.key in unmodified - ]) - for k in self.expired_attributes: - self.callables.pop(k, None) - del self.expired_attributes - return ATTR_WAS_SET - - @property - def unmodified(self): - """a set of keys which have no uncommitted changes""" - - return set( - key for key in self.manager.iterkeys() - if (key not in self.committed_state or - (key in self.manager.mutable_attributes and - not self.manager[key].impl.check_mutable_modified(self)))) - - @property - def unloaded(self): - """a set of keys which do not have a loaded value. - - This includes expired attributes and any other attribute that - was never populated or modified. - - """ - return set( - key for key in self.manager.iterkeys() - if key not in self.committed_state and key not in self.dict) - - def expire_attributes(self, attribute_names): - self.expired_attributes = set(self.expired_attributes) - - if attribute_names is None: - attribute_names = self.manager.keys() - self.expired = True - self.modified = False - filter_deferred = True - else: - filter_deferred = False - for key in attribute_names: - impl = self.manager[key].impl - if not filter_deferred or \ - not impl.dont_expire_missing or \ - key in self.dict: - self.expired_attributes.add(key) - if impl.accepts_scalar_loader: - self.callables[key] = self - self.dict.pop(key, None) - self.pending.pop(key, None) - self.committed_state.pop(key, None) - - def reset(self, key): - """remove the given attribute and any callables associated with it.""" - - self.dict.pop(key, None) - self.callables.pop(key, None) - - def modified_event(self, attr, should_copy, previous, passive=PASSIVE_OFF): - needs_committed = attr.key not in self.committed_state - - if needs_committed: - if previous is NEVER_SET: - if passive: - if attr.key in self.dict: - previous = self.dict[attr.key] - else: - previous = attr.get(self) - - if should_copy and previous not in (None, NO_VALUE, NEVER_SET): - previous = attr.copy(previous) - - if needs_committed: - self.committed_state[attr.key] = previous - - self.modified = True - - def commit(self, keys): - """Commit attributes. - - This is used by a partial-attribute load operation to mark committed - those attributes which were refreshed from the database. - - Attributes marked as "expired" can potentially remain "expired" after - this step if a value was not populated in state.dict. - - """ - class_manager = self.manager - for key in keys: - if key in self.dict and key in class_manager.mutable_attributes: - class_manager[key].impl.commit_to_state(self, self.committed_state) - else: - self.committed_state.pop(key, None) - - self.expired = False - # unexpire attributes which have loaded - for key in self.expired_attributes.intersection(keys): - if key in self.dict: - self.expired_attributes.remove(key) - self.callables.pop(key, None) - - def commit_all(self): - """commit all attributes unconditionally. - - This is used after a flush() or a full load/refresh - to remove all pending state from the instance. - - - all attributes are marked as "committed" - - the "strong dirty reference" is removed - - the "modified" flag is set to False - - any "expired" markers/callables are removed. - - Attributes marked as "expired" can potentially remain "expired" after this step - if a value was not populated in state.dict. - - """ - - self.committed_state = {} - self.pending = {} - - # unexpire attributes which have loaded - if self.expired_attributes: - for key in self.expired_attributes.intersection(self.dict): - self.callables.pop(key, None) - self.expired_attributes.difference_update(self.dict) - - for key in self.manager.mutable_attributes: - if key in self.dict: - self.manager[key].impl.commit_to_state(self, self.committed_state) - - self.modified = self.expired = False - self._strong_obj = None + child_state, child_dict = instance_state(child), instance_dict(child) + child_state.get_impl(self.key).remove(child_state, child_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) class Events(object): @@ -1128,6 +845,7 @@ class Events(object): self.on_init = () self.on_init_failure = () self.on_load = () + self.on_resurrect = () def run(self, event, *args, **kwargs): for fn in getattr(self, event): @@ -1153,7 +871,6 @@ class ClassManager(dict): STATE_ATTR = '_sa_instance_state' event_registry_factory = Events - instance_state_factory = InstanceState deferred_scalar_loader = None def __init__(self, class_): @@ -1177,7 +894,6 @@ class ClassManager(dict): def _configure_create_arguments(self, _source=None, - instance_state_factory=None, deferred_scalar_loader=None): """Accept extra **kw arguments passed to create_manager_for_cls. @@ -1192,11 +908,8 @@ class ClassManager(dict): """ if _source: - instance_state_factory = _source.instance_state_factory deferred_scalar_loader = _source.deferred_scalar_loader - if instance_state_factory: - self.instance_state_factory = instance_state_factory if deferred_scalar_loader: self.deferred_scalar_loader = deferred_scalar_loader @@ -1229,7 +942,16 @@ class ClassManager(dict): if self.new_init: self.uninstall_member('__init__') self.new_init = None - + + def _create_instance_state(self, instance): + global state + if state is None: + from sqlalchemy.orm import state + if self.mutable_attributes: + return state.MutableAttrInstanceState(instance, self) + else: + return state.InstanceState(instance, self) + def manage(self): """Mark this instance as the manager for its class.""" @@ -1337,11 +1059,11 @@ class ClassManager(dict): def new_instance(self, state=None): instance = self.class_.__new__(self.class_) - setattr(instance, self.STATE_ATTR, state or self.instance_state_factory(instance, self)) + setattr(instance, self.STATE_ATTR, state or self._create_instance_state(instance)) return instance def setup_instance(self, instance, state=None): - setattr(instance, self.STATE_ATTR, state or self.instance_state_factory(instance, self)) + setattr(instance, self.STATE_ATTR, state or self._create_instance_state(instance)) def teardown_instance(self, instance): delattr(instance, self.STATE_ATTR) @@ -1355,13 +1077,10 @@ class ClassManager(dict): if hasattr(instance, self.STATE_ATTR): return False else: - state = self.instance_state_factory(instance, self) + state = self._create_instance_state(instance) setattr(instance, self.STATE_ATTR, state) return state - def state_of(self, instance): - return getattr(instance, self.STATE_ATTR) - def state_getter(self): """Return a (instance) -> InstanceState callable. @@ -1372,6 +1091,9 @@ class ClassManager(dict): return attrgetter(self.STATE_ATTR) + def dict_getter(self): + return attrgetter('__dict__') + def has_state(self, instance): return hasattr(instance, self.STATE_ATTR) @@ -1392,6 +1114,9 @@ class _ClassInstrumentationAdapter(ClassManager): def __init__(self, class_, override, **kw): self._adapted = override + self._get_state = self._adapted.state_getter(class_) + self._get_dict = self._adapted.dict_getter(class_) + ClassManager.__init__(self, class_, **kw) def manage(self): @@ -1453,36 +1178,27 @@ class _ClassInstrumentationAdapter(ClassManager): self._adapted.initialize_instance_dict(self.class_, instance) if state is None: - state = self.instance_state_factory(instance, self) + state = self._create_instance_state(instance) # the given instance is assumed to have no state self._adapted.install_state(self.class_, instance, state) - state.dict = self._adapted.get_instance_dict(self.class_, instance) return state def teardown_instance(self, instance): self._adapted.remove_state(self.class_, instance) - def state_of(self, instance): - if hasattr(self._adapted, 'state_of'): - return self._adapted.state_of(self.class_, instance) - else: - getter = self._adapted.state_getter(self.class_) - return getter(instance) - def has_state(self, instance): - if hasattr(self._adapted, 'has_state'): - return self._adapted.has_state(self.class_, instance) - else: - try: - state = self.state_of(instance) - return True - except exc.NO_STATE: - return False + try: + state = self._get_state(instance) + return True + except exc.NO_STATE: + return False def state_getter(self): - return self._adapted.state_getter(self.class_) + return self._get_state + def dict_getter(self): + return self._get_dict class History(tuple): """A 3-tuple of added, unchanged and deleted values. @@ -1527,7 +1243,7 @@ class History(tuple): original = state.committed_state.get(attribute.key, NEVER_SET) if hasattr(attribute, 'get_collection'): - current = attribute.get_collection(state, current) + current = attribute.get_collection(state, state.dict, current) if original is NO_VALUE: return cls(list(current), (), ()) elif original is NEVER_SET: @@ -1564,30 +1280,8 @@ class History(tuple): HISTORY_BLANK = History(None, None, None) -class PendingCollection(object): - """A writable placeholder for an unloaded collection. - - Stores items appended to and removed from a collection that has not yet - been loaded. When the collection is loaded, the changes stored in - PendingCollection are applied to it to produce the final result. - - """ - def __init__(self): - self.deleted_items = util.IdentitySet() - self.added_items = util.OrderedIdentitySet() - - def append(self, value): - if value in self.deleted_items: - self.deleted_items.remove(value) - self.added_items.add(value) - - def remove(self, value): - if value in self.added_items: - self.added_items.remove(value) - self.deleted_items.add(value) - def _conditional_instance_state(obj): - if not isinstance(obj, InstanceState): + if not isinstance(obj, state.InstanceState): obj = instance_state(obj) return obj @@ -1697,15 +1391,16 @@ def init_collection(obj, key): this usage is deprecated. """ - - return init_state_collection(_conditional_instance_state(obj), key) + state = _conditional_instance_state(obj) + dict_ = state.dict + return init_state_collection(state, dict_, key) -def init_state_collection(state, key): +def init_state_collection(state, dict_, key): """Initialize a collection attribute and return the collection adapter.""" attr = state.get_impl(key) - user_data = attr.initialize(state) - return attr.get_collection(state, user_data) + user_data = attr.initialize(state, dict_) + return attr.get_collection(state, dict_, user_data) def set_committed_value(instance, key, value): """Set the value of an attribute with no history events. @@ -1722,8 +1417,8 @@ def set_committed_value(instance, key, value): as though it were part of its original loaded state. """ - state = instance_state(instance) - state.get_impl(key).set_committed_value(instance, key, value) + state, dict_ = instance_state(instance), instance_dict(instance) + state.get_impl(key).set_committed_value(state, dict_, key, value) def set_attribute(instance, key, value): """Set the value of an attribute, firing history events. @@ -1735,8 +1430,8 @@ def set_attribute(instance, key, value): by SQLAlchemy. """ - state = instance_state(instance) - state.get_impl(key).set(state, value, None) + state, dict_ = instance_state(instance), instance_dict(instance) + state.get_impl(key).set(state, dict_, value, None) def get_attribute(instance, key): """Get the value of an attribute, firing any callables required. @@ -1748,8 +1443,8 @@ def get_attribute(instance, key): by SQLAlchemy. """ - state = instance_state(instance) - return state.get_impl(key).get(state) + state, dict_ = instance_state(instance), instance_dict(instance) + return state.get_impl(key).get(state, dict_) def del_attribute(instance, key): """Delete the value of an attribute, firing history events. @@ -1761,8 +1456,8 @@ def del_attribute(instance, key): by SQLAlchemy. """ - state = instance_state(instance) - state.get_impl(key).delete(state) + state, dict_ = instance_state(instance), instance_dict(instance) + state.get_impl(key).delete(state, dict_) def is_instrumented(instance, key): """Return True if the given attribute on the given instance is instrumented @@ -1779,6 +1474,7 @@ class InstrumentationRegistry(object): _manager_finders = weakref.WeakKeyDictionary() _state_finders = util.WeakIdentityMapping() + _dict_finders = util.WeakIdentityMapping() _extended = False def create_manager_for_cls(self, class_, **kw): @@ -1813,6 +1509,7 @@ class InstrumentationRegistry(object): manager.factory = factory self._manager_finders[class_] = manager.manager_getter() self._state_finders[class_] = manager.state_getter() + self._dict_finders[class_] = manager.dict_getter() return manager def _collect_management_factories_for(self, cls): @@ -1852,6 +1549,7 @@ class InstrumentationRegistry(object): return finder(cls) def state_of(self, instance): + # this is only called when alternate instrumentation has been established if instance is None: raise AttributeError("None has no persistent state.") try: @@ -1859,21 +1557,15 @@ class InstrumentationRegistry(object): except KeyError: raise AttributeError("%r is not instrumented" % instance.__class__) - def state_or_default(self, instance, default=None): + def dict_of(self, instance): + # this is only called when alternate instrumentation has been established if instance is None: - return default + raise AttributeError("None has no persistent state.") try: - finder = self._state_finders[instance.__class__] + return self._dict_finders[instance.__class__](instance) except KeyError: - return default - else: - try: - return finder(instance) - except exc.NO_STATE: - return default - except: - raise - + raise AttributeError("%r is not instrumented" % instance.__class__) + def unregister(self, class_): if class_ in self._manager_finders: manager = self.manager_of_class(class_) @@ -1881,6 +1573,7 @@ class InstrumentationRegistry(object): manager.dispose() del self._manager_finders[class_] del self._state_finders[class_] + del self._dict_finders[class_] instrumentation_registry = InstrumentationRegistry() @@ -1894,12 +1587,14 @@ def _install_lookup_strategy(implementation): and unit tests specific to this behavior. """ - global instance_state + global instance_state, instance_dict if implementation is util.symbol('native'): instance_state = attrgetter(ClassManager.STATE_ATTR) + instance_dict = attrgetter("__dict__") else: instance_state = instrumentation_registry.state_of - + instance_dict = instrumentation_registry.dict_of + manager_of_class = instrumentation_registry.manager_of_class _create_manager_for_cls = instrumentation_registry.create_manager_for_cls _install_lookup_strategy(util.symbol('native')) diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 5903d34927..b865c11f46 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -472,6 +472,7 @@ class CollectionAdapter(object): """ def __init__(self, attr, owner_state, data): self.attr = attr + # TODO: figure out what this being a weakref buys us self._data = weakref.ref(data) self.owner_state = owner_state self.link_to_self(data) @@ -578,7 +579,7 @@ class CollectionAdapter(object): """ if initiator is not False and item is not None: - return self.attr.fire_append_event(self.owner_state, item, initiator) + return self.attr.fire_append_event(self.owner_state, self.owner_state.dict, item, initiator) else: return item @@ -591,7 +592,7 @@ class CollectionAdapter(object): """ if initiator is not False and item is not None: - self.attr.fire_remove_event(self.owner_state, item, initiator) + self.attr.fire_remove_event(self.owner_state, self.owner_state.dict, item, initiator) def fire_pre_remove_event(self, initiator=None): """Notify that an entity is about to be removed from the collection. @@ -600,7 +601,7 @@ class CollectionAdapter(object): fire_remove_event(). """ - self.attr.fire_pre_remove_event(self.owner_state, initiator=initiator) + self.attr.fire_pre_remove_event(self.owner_state, self.owner_state.dict, initiator=initiator) def __getstate__(self): return {'key': self.attr.key, diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index c4ba7852f9..f3820eb7cd 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -64,17 +64,21 @@ class DependencyProcessor(object): def register_dependencies(self, uowcommit): """Tell a ``UOWTransaction`` what mappers are dependent on which, with regards to the two or three mappers handled by - this ``PropertyLoader``. + this ``DependencyProcessor``. - Also register itself as a *processor* for one of its mappers, - which will be executed after that mapper's objects have been - saved or before they've been deleted. The process operation - manages attributes and dependent operations upon the objects - of one of the involved mappers. """ raise NotImplementedError() + def register_processors(self, uowcommit): + """Tell a ``UOWTransaction`` about this object as a processor, + which will be executed after that mapper's objects have been + saved or before they've been deleted. The process operation + manages attributes and dependent operations between two mappers. + + """ + raise NotImplementedError() + def whose_dependent_on_who(self, state1, state2): """Given an object pair assuming `obj2` is a child of `obj1`, return a tuple with the dependent object second, or None if @@ -181,9 +185,13 @@ class OneToManyDP(DependencyProcessor): if self.post_update: uowcommit.register_dependency(self.mapper, self.dependency_marker) uowcommit.register_dependency(self.parent, self.dependency_marker) - uowcommit.register_processor(self.dependency_marker, self, self.parent) else: uowcommit.register_dependency(self.parent, self.mapper) + + def register_processors(self, uowcommit): + if self.post_update: + uowcommit.register_processor(self.dependency_marker, self, self.parent) + else: uowcommit.register_processor(self.parent, self, self.parent) def process_dependencies(self, task, deplist, uowcommit, delete = False): @@ -257,11 +265,13 @@ class OneToManyDP(DependencyProcessor): uowcommit.register_object( attributes.instance_state(c), isdelete=True) - if not self.passive_updates and self._pks_changed(uowcommit, state): + if self._pks_changed(uowcommit, state): if not history: - history = uowcommit.get_attribute_history(state, self.key, passive=False) - for child in history.unchanged: - uowcommit.register_object(child) + history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_updates) + if history: + for child in history.unchanged: + if child is not None: + uowcommit.register_object(child) def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): source = state @@ -275,7 +285,7 @@ class OneToManyDP(DependencyProcessor): sync.populate(source, self.parent, dest, self.mapper, self.prop.synchronize_pairs) def _pks_changed(self, uowcommit, state): - return sync.source_changes(uowcommit, state, self.parent, self.prop.synchronize_pairs) + return sync.source_modified(uowcommit, state, self.parent, self.prop.synchronize_pairs) class DetectKeySwitch(DependencyProcessor): """a special DP that works for many-to-one relations, fires off for @@ -284,6 +294,9 @@ class DetectKeySwitch(DependencyProcessor): no_dependencies = True def register_dependencies(self, uowcommit): + pass + + def register_processors(self, uowcommit): uowcommit.register_processor(self.parent, self, self.mapper) def preprocess_dependencies(self, task, deplist, uowcommit, delete=False): @@ -314,11 +327,11 @@ class DetectKeySwitch(DependencyProcessor): elem.dict[self.key] is not None and attributes.instance_state(elem.dict[self.key]) in switchers ]: - uowcommit.register_object(s, listonly=self.passive_updates) + uowcommit.register_object(s) sync.populate(attributes.instance_state(s.dict[self.key]), self.mapper, s, self.parent, self.prop.synchronize_pairs) def _pks_changed(self, uowcommit, state): - return sync.source_changes(uowcommit, state, self.mapper, self.prop.synchronize_pairs) + return sync.source_modified(uowcommit, state, self.mapper, self.prop.synchronize_pairs) class ManyToOneDP(DependencyProcessor): def __init__(self, prop): @@ -329,12 +342,15 @@ class ManyToOneDP(DependencyProcessor): if self.post_update: uowcommit.register_dependency(self.mapper, self.dependency_marker) uowcommit.register_dependency(self.parent, self.dependency_marker) - uowcommit.register_processor(self.dependency_marker, self, self.parent) else: uowcommit.register_dependency(self.mapper, self.parent) + + def register_processors(self, uowcommit): + if self.post_update: + uowcommit.register_processor(self.dependency_marker, self, self.parent) + else: uowcommit.register_processor(self.mapper, self, self.parent) - def process_dependencies(self, task, deplist, uowcommit, delete=False): if delete: if self.post_update and not self.cascade.delete_orphan and not self.passive_deletes == 'all': @@ -407,8 +423,10 @@ class ManyToManyDP(DependencyProcessor): uowcommit.register_dependency(self.parent, self.dependency_marker) uowcommit.register_dependency(self.mapper, self.dependency_marker) - uowcommit.register_processor(self.dependency_marker, self, self.parent) + def register_processors(self, uowcommit): + uowcommit.register_processor(self.dependency_marker, self, self.parent) + def process_dependencies(self, task, deplist, uowcommit, delete = False): connection = uowcommit.transaction.connection(self.mapper) secondary_delete = [] @@ -502,7 +520,7 @@ class ManyToManyDP(DependencyProcessor): sync.populate_dict(child, self.mapper, associationrow, self.prop.secondary_synchronize_pairs) def _pks_changed(self, uowcommit, state): - return sync.source_changes(uowcommit, state, self.parent, self.prop.synchronize_pairs) + return sync.source_modified(uowcommit, state, self.parent, self.prop.synchronize_pairs) class MapperStub(object): """Represent a many-to-many dependency within a flush @@ -526,6 +544,9 @@ class MapperStub(object): def _register_dependencies(self, uowcommit): pass + def _register_procesors(self, uowcommit): + pass + def _save_obj(self, *args, **kwargs): pass diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index 3d31a686a2..70243291dc 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -55,21 +55,21 @@ class DynamicAttributeImpl(attributes.AttributeImpl): else: self.query_class = mixin_user_query(query_class) - def get(self, state, passive=False): + def get(self, state, dict_, passive=False): if passive: return self._get_collection_history(state, passive=True).added_items else: return self.query_class(self, state) - def get_collection(self, state, user_data=None, passive=True): + def get_collection(self, state, dict_, user_data=None, passive=True): if passive: return self._get_collection_history(state, passive=passive).added_items else: history = self._get_collection_history(state, passive=passive) return history.added_items + history.unchanged_items - def fire_append_event(self, state, value, initiator): - collection_history = self._modified_event(state) + def fire_append_event(self, state, dict_, value, initiator): + collection_history = self._modified_event(state, dict_) collection_history.added_items.append(value) for ext in self.extensions: @@ -78,8 +78,8 @@ class DynamicAttributeImpl(attributes.AttributeImpl): if self.trackparent and value is not None: self.sethasparent(attributes.instance_state(value), True) - def fire_remove_event(self, state, value, initiator): - collection_history = self._modified_event(state) + def fire_remove_event(self, state, dict_, value, initiator): + collection_history = self._modified_event(state, dict_) collection_history.deleted_items.append(value) if self.trackparent and value is not None: @@ -88,31 +88,31 @@ class DynamicAttributeImpl(attributes.AttributeImpl): for ext in self.extensions: ext.remove(state, value, initiator or self) - def _modified_event(self, state): + def _modified_event(self, state, dict_): if self.key not in state.committed_state: state.committed_state[self.key] = CollectionHistory(self, state) - state.modified_event(self, False, attributes.NEVER_SET, passive=attributes.PASSIVE_NO_INITIALIZE) + state.modified_event(dict_, self, False, attributes.NEVER_SET, passive=attributes.PASSIVE_NO_INITIALIZE) # this is a hack to allow the _base.ComparableEntity fixture # to work - state.dict[self.key] = True + dict_[self.key] = True return state.committed_state[self.key] - def set(self, state, value, initiator): + def set(self, state, dict_, value, initiator): if initiator is self: return - self._set_iterable(state, value) + self._set_iterable(state, dict_, value) - def _set_iterable(self, state, iterable, adapter=None): + def _set_iterable(self, state, dict_, iterable, adapter=None): - collection_history = self._modified_event(state) + collection_history = self._modified_event(state, dict_) new_values = list(iterable) if _state_has_identity(state): - old_collection = list(self.get(state)) + old_collection = list(self.get(state, dict_)) else: old_collection = [] @@ -121,7 +121,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl): def delete(self, *args, **kwargs): raise NotImplementedError() - def get_history(self, state, passive=False): + def get_history(self, state, dict_, passive=False): c = self._get_collection_history(state, passive) return attributes.History(c.added_items, c.unchanged_items, c.deleted_items) @@ -136,13 +136,13 @@ class DynamicAttributeImpl(attributes.AttributeImpl): else: return c - def append(self, state, value, initiator, passive=False): + def append(self, state, dict_, value, initiator, passive=False): if initiator is not self: - self.fire_append_event(state, value, initiator) + self.fire_append_event(state, dict_, value, initiator) - def remove(self, state, value, initiator, passive=False): + def remove(self, state, dict_, value, initiator, passive=False): if initiator is not self: - self.fire_remove_event(state, value, initiator) + self.fire_remove_event(state, dict_, value, initiator) class DynCollectionAdapter(object): """the dynamic analogue to orm.collections.CollectionAdapter""" @@ -156,10 +156,10 @@ class DynCollectionAdapter(object): return iter(self.data) def append_with_event(self, item, initiator=None): - self.attr.append(self.state, item, initiator) + self.attr.append(self.state, self.state.dict, item, initiator) def remove_with_event(self, item, initiator=None): - self.attr.remove(self.state, item, initiator) + self.attr.remove(self.state, self.state.dict, item, initiator) def append_without_event(self, item): pass @@ -240,10 +240,10 @@ class AppenderMixin(object): return query def append(self, item): - self.attr.append(attributes.instance_state(self.instance), item, None) + self.attr.append(attributes.instance_state(self.instance), attributes.instance_dict(self.instance), item, None) def remove(self, item): - self.attr.remove(attributes.instance_state(self.instance), item, None) + self.attr.remove(attributes.instance_state(self.instance), attributes.instance_dict(self.instance), item, None) class AppenderQuery(AppenderMixin, Query): diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py index 0829d18015..71527c686d 100644 --- a/lib/sqlalchemy/orm/identity.py +++ b/lib/sqlalchemy/orm/identity.py @@ -12,9 +12,12 @@ from sqlalchemy.orm import attributes class IdentityMap(dict): def __init__(self): - self._mutable_attrs = {} - self.modified = False + self._mutable_attrs = set() + self._modified = set() self._wr = weakref.ref(self) + + def replace(self, state): + raise NotImplementedError() def add(self, state): raise NotImplementedError() @@ -31,28 +34,29 @@ class IdentityMap(dict): def _manage_incoming_state(self, state): state._instance_dict = self._wr - if state.modified: - self.modified = True + if state.modified: + self._modified.add(state) if state.manager.mutable_attributes: - self._mutable_attrs[state] = True + self._mutable_attrs.add(state) def _manage_removed_state(self, state): del state._instance_dict + self._mutable_attrs.discard(state) + self._modified.discard(state) + + def _dirty_states(self): + return self._modified.union(s for s in self._mutable_attrs if s.modified) - if state in self._mutable_attrs: - del self._mutable_attrs[state] - def check_modified(self): """return True if any InstanceStates present have been marked as 'modified'.""" - if not self.modified: - for state in list(self._mutable_attrs): - if state.check_modified(): - return True - else: - return False - else: + if self._modified: return True + else: + for state in self._mutable_attrs: + if state.modified: + return True + return False def has_key(self, key): return key in self @@ -102,6 +106,17 @@ class WeakInstanceDict(IdentityMap): def contains_state(self, state): return dict.get(self, state.key) is state + def replace(self, state): + if dict.__contains__(self, state.key): + existing = dict.__getitem__(self, state.key) + if existing is not state: + self._manage_removed_state(existing) + else: + return + + dict.__setitem__(self, state.key, state) + self._manage_incoming_state(state) + def add(self, state): if state.key in self: if dict.__getitem__(self, state.key) is not state: @@ -176,12 +191,24 @@ class StrongInstanceDict(IdentityMap): def contains_state(self, state): return state.key in self and attributes.instance_state(self[state.key]) is state + def replace(self, state): + if dict.__contains__(self, state.key): + existing = dict.__getitem__(self, state.key) + existing = attributes.instance_state(existing) + if existing is not state: + self._manage_removed_state(existing) + else: + return + + dict.__setitem__(self, state.key, state.obj()) + self._manage_incoming_state(state) + def add(self, state): dict.__setitem__(self, state.key, state.obj()) self._manage_incoming_state(state) def remove(self, state): - if dict.pop(self, state.key) is not state: + if attributes.instance_state(dict.pop(self, state.key)) is not state: raise AssertionError("State %s is not present in this identity map" % state) self._manage_removed_state(state) @@ -191,7 +218,7 @@ class StrongInstanceDict(IdentityMap): self._manage_removed_state(state) def remove_key(self, key): - state = dict.__getitem__(self, key) + state = attributes.instance_state(dict.__getitem__(self, key)) self.remove(state) def prune(self): @@ -205,62 +232,3 @@ class StrongInstanceDict(IdentityMap): self.modified = bool(dirty) return ref_count - len(self) -class IdentityManagedState(attributes.InstanceState): - def _instance_dict(self): - return None - - def modified_event(self, attr, should_copy, previous, passive=False): - attributes.InstanceState.modified_event(self, attr, should_copy, previous, passive) - - instance_dict = self._instance_dict() - if instance_dict: - instance_dict.modified = True - - def _is_really_none(self): - """do a check modified/resurrect. - - This would be called in the extremely rare - race condition that the weakref returned None but - the cleanup handler had not yet established the - __resurrect callable as its replacement. - - """ - if self.check_modified(): - self.obj = self.__resurrect - return self.obj() - else: - return None - - def _cleanup(self, ref): - """weakref callback. - - This method may be called by an asynchronous - gc. - - If the state shows pending changes, the weakref - is replaced by the __resurrect callable which will - re-establish an object reference on next access, - else removes this InstanceState from the owning - identity map, if any. - - """ - if self.check_modified(): - self.obj = self.__resurrect - else: - instance_dict = self._instance_dict() - if instance_dict: - instance_dict.remove(self) - self.dispose() - - def __resurrect(self): - """A substitute for the obj() weakref function which resurrects.""" - - # store strong ref'ed version of the object; will revert - # to weakref when changes are persisted - obj = self.manager.new_instance(state=self) - self.obj = weakref.ref(obj, self._cleanup) - self._strong_obj = obj - obj.__dict__.update(self.dict) - self.dict = obj.__dict__ - self._run_on_load(obj) - return obj diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index d36f51194e..0ac7713058 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -359,7 +359,7 @@ class MapperProperty(object): Callables are of the following form:: - def new_execute(state, row, **flags): + def new_execute(state, dict_, row, **flags): # process incoming instance state and given row. the instance is # "new" and was just created upon receipt of this row. # flags is a dictionary containing at least the following @@ -368,7 +368,7 @@ class MapperProperty(object): # result of reading this row # instancekey - identity key of the instance - def existing_execute(state, row, **flags): + def existing_execute(state, dict_, row, **flags): # process incoming instance state and given row. the instance is # "existing" and was created based on a previous row. @@ -427,13 +427,23 @@ class MapperProperty(object): def register_dependencies(self, *args, **kwargs): """Called by the ``Mapper`` in response to the UnitOfWork calling the ``Mapper``'s register_dependencies operation. - Should register with the UnitOfWork all inter-mapper - dependencies as well as dependency processors (see UOW docs - for more details). + Establishes a topological dependency between two mappers + which will affect the order in which mappers persist data. + """ pass + def register_processors(self, *args, **kwargs): + """Called by the ``Mapper`` in response to the UnitOfWork + calling the ``Mapper``'s register_processors operation. + Establishes a processor object between two mappers which + will link data and state between parent/child objects. + + """ + + pass + def is_primary(self): """Return True if this ``MapperProperty``'s mapper is the primary mapper for its class. @@ -939,3 +949,7 @@ class InstrumentationManager(object): def state_getter(self, class_): return lambda instance: getattr(instance, '_default_state') + + def dict_getter(self, class_): + return lambda inst: self.get_instance_dict(class_, inst) + \ No newline at end of file diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index e5dbb4d039..1502060f02 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -23,7 +23,6 @@ deque = __import__('collections').deque from sqlalchemy import sql, util, log, exc as sa_exc from sqlalchemy.sql import expression, visitors, operators, util as sqlutil from sqlalchemy.orm import attributes, exc, sync -from sqlalchemy.orm.identity import IdentityManagedState from sqlalchemy.orm.interfaces import ( MapperProperty, EXT_CONTINUE, PropComparator ) @@ -255,7 +254,8 @@ class Mapper(object): for mapper in self.iterate_to_root(): util.reset_memoized(mapper, '_equivalent_columns') - + util.reset_memoized(mapper, '_sorted_tables') + if self.order_by is False and not self.concrete and self.inherits.order_by is not False: self.order_by = self.inherits.order_by @@ -357,7 +357,6 @@ class Mapper(object): if manager is None: manager = attributes.register_class(self.class_, - instance_state_factory = IdentityManagedState, deferred_scalar_loader = _load_scalar_attributes ) @@ -372,6 +371,8 @@ class Mapper(object): event_registry = manager.events event_registry.add_listener('on_init', _event_on_init) event_registry.add_listener('on_init_failure', _event_on_init_failure) + event_registry.add_listener('on_resurrect', _event_on_resurrect) + for key, method in util.iterate_attributes(self.class_): if isinstance(method, types.FunctionType): if hasattr(method, '__sa_reconstructor__'): @@ -682,7 +683,7 @@ class Mapper(object): for key, prop in l: self._log("initialize prop " + key) - if not prop._compile_started: + if prop.parent is self and not prop._compile_started: prop.init() if prop._compile_finished: @@ -1173,6 +1174,19 @@ class Mapper(object): # persistence + @util.memoized_property + def _sorted_tables(self): + table_to_mapper = {} + for mapper in self.base_mapper.polymorphic_iterator(): + for t in mapper.tables: + table_to_mapper[t] = mapper + + sorted_ = sqlutil.sort_tables(table_to_mapper.iterkeys()) + ret = util.OrderedDict() + for t in sorted_: + ret[t] = table_to_mapper[t] + return ret + def _save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False): """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. @@ -1198,16 +1212,37 @@ class Mapper(object): # if session has a connection callable, # organize individual states with the connection to use for insert/update + tups = [] if 'connection_callable' in uowtransaction.mapper_flush_opts: connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] - tups = [(state, _state_mapper(state), connection_callable(self, state.obj()), _state_has_identity(state)) for state in _sort_states(states)] + for state in _sort_states(states): + m = _state_mapper(state) + tups.append( + ( + state, + m, + connection_callable(self, state.obj()), + _state_has_identity(state), + state.key or m._identity_key_from_state(state) + ) + ) else: connection = uowtransaction.transaction.connection(self) - tups = [(state, _state_mapper(state), connection, _state_has_identity(state)) for state in _sort_states(states)] + for state in _sort_states(states): + m = _state_mapper(state) + tups.append( + ( + state, + m, + connection, + _state_has_identity(state), + state.key or m._identity_key_from_state(state) + ) + ) if not postupdate: # call before_XXX extensions - for state, mapper, connection, has_identity in tups: + for state, mapper, connection, has_identity, instance_key in tups: if not has_identity: if 'before_insert' in mapper.extension: mapper.extension.before_insert(mapper, connection, state.obj()) @@ -1215,39 +1250,44 @@ class Mapper(object): if 'before_update' in mapper.extension: mapper.extension.before_update(mapper, connection, state.obj()) - for state, mapper, connection, has_identity in tups: - # detect if we have a "pending" instance (i.e. has no instance_key attached to it), - # and another instance with the same identity key already exists as persistent. convert to an - # UPDATE if so. - instance_key = mapper._identity_key_from_state(state) - if not postupdate and not has_identity and instance_key in uowtransaction.session.identity_map: - instance = uowtransaction.session.identity_map[instance_key] - existing = attributes.instance_state(instance) - if not uowtransaction.is_deleted(existing): - raise exc.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (state_str(state), str(instance_key), state_str(existing))) - if self._should_log_debug: - self._log_debug("detected row switch for identity %s. will update %s, remove %s from transaction" % (instance_key, state_str(state), state_str(existing))) - uowtransaction.set_row_switch(existing) - - table_to_mapper = {} - for mapper in self.base_mapper.polymorphic_iterator(): - for t in mapper.tables: - table_to_mapper[t] = mapper + row_switches = set() + if not postupdate: + for state, mapper, connection, has_identity, instance_key in tups: + # detect if we have a "pending" instance (i.e. has no instance_key attached to it), + # and another instance with the same identity key already exists as persistent. convert to an + # UPDATE if so. + if not has_identity and instance_key in uowtransaction.session.identity_map: + instance = uowtransaction.session.identity_map[instance_key] + existing = attributes.instance_state(instance) + if not uowtransaction.is_deleted(existing): + raise exc.FlushError( + "New instance %s with identity key %s conflicts with persistent instance %s" % + (state_str(state), instance_key, state_str(existing))) + if self._should_log_debug: + self._log_debug( + "detected row switch for identity %s. will update %s, remove %s from transaction", + instance_key, state_str(state), state_str(existing)) + + # remove the "delete" flag from the existing element + uowtransaction.set_row_switch(existing) + row_switches.add(state) + + table_to_mapper = self._sorted_tables - for table in sqlutil.sort_tables(table_to_mapper.iterkeys()): + for table in table_to_mapper.iterkeys(): insert = [] update = [] - for state, mapper, connection, has_identity in tups: + for state, mapper, connection, has_identity, instance_key in tups: if table not in mapper._pks_by_table: continue + pks = mapper._pks_by_table[table] - instance_key = mapper._identity_key_from_state(state) - + if self._should_log_debug: self._log_debug("_save_obj() table '%s' instance %s identity %s" % (table.name, state_str(state), str(instance_key))) - isinsert = not instance_key in uowtransaction.session.identity_map and not postupdate and not has_identity + isinsert = not has_identity and not postupdate and state not in row_switches params = {} value_params = {} @@ -1257,10 +1297,6 @@ class Mapper(object): for col in mapper._cols_by_table[table]: if col is mapper.version_id_col: params[col.key] = 1 - elif col in pks: - value = mapper._get_state_attr_by_column(state, col) - if value is not None: - params[col.key] = value elif mapper.polymorphic_on and mapper.polymorphic_on.shares_lineage(col): if self._should_log_debug: self._log_debug("Using polymorphic identity '%s' for insert column '%s'" % (mapper.polymorphic_identity, col.key)) @@ -1269,6 +1305,10 @@ class Mapper(object): col.server_default is None) or value is not None): params[col.key] = value + elif col in pks: + value = mapper._get_state_attr_by_column(state, col) + if value is not None: + params[col.key] = value else: value = mapper._get_state_attr_by_column(state, col) if ((col.default is None and @@ -1364,7 +1404,7 @@ class Mapper(object): sync.populate(state, m, state, m, m._inherits_equated_pairs) if not postupdate: - for state, mapper, connection, has_identity in tups: + for state, mapper, connection, has_identity, instance_key in tups: # expire readonly attributes readonly = state.unmodified.intersection( @@ -1434,12 +1474,9 @@ class Mapper(object): if 'before_delete' in mapper.extension: mapper.extension.before_delete(mapper, connection, state.obj()) - table_to_mapper = {} - for mapper in self.base_mapper.polymorphic_iterator(): - for t in mapper.tables: - table_to_mapper[t] = mapper + table_to_mapper = self._sorted_tables - for table in reversed(sqlutil.sort_tables(table_to_mapper.iterkeys())): + for table in reversed(table_to_mapper.keys()): delete = {} for state, mapper, connection in tups: if table not in mapper._pks_by_table: @@ -1485,6 +1522,10 @@ class Mapper(object): for dep in self._props.values() + self._dependency_processors: dep.register_dependencies(uowcommit) + def _register_processors(self, uowcommit): + for dep in self._props.values() + self._dependency_processors: + dep.register_processors(uowcommit) + # result set conversion def _instance_processor(self, context, path, adapter, polymorphic_from=None, extension=None, only_load_props=None, refresh_state=None, polymorphic_discriminator=None): @@ -1514,7 +1555,13 @@ class Mapper(object): new_populators = [] existing_populators = [] - def populate_state(state, row, isnew, only_load_props, **flags): + def populate_state(state, dict_, row, isnew, only_load_props, **flags): + if isnew: + if context.options: + state.load_options = context.options + if state.load_options: + state.load_path = context.query._current_path + path + if isnew: if context.options: state.load_options = context.options @@ -1533,7 +1580,7 @@ class Mapper(object): populators = [p for p in populators if p[0] in only_load_props] for key, populator in populators: - populator(state, row, isnew=isnew, **flags) + populator(state, dict_, row, isnew=isnew, **flags) session_identity_map = context.session.identity_map @@ -1573,9 +1620,11 @@ class Mapper(object): if identitykey in session_identity_map: instance = session_identity_map[identitykey] state = attributes.instance_state(instance) + dict_ = attributes.instance_dict(instance) if self._should_log_debug: - self._log_debug("_instance(): using existing instance %s identity %s" % (instance_str(instance), identitykey)) + self._log_debug("_instance(): using existing instance %s identity %s", + instance_str(instance), identitykey) isnew = state.runid != context.runid currentload = not isnew @@ -1592,12 +1641,13 @@ class Mapper(object): # when eager_defaults is True. state = refresh_state instance = state.obj() + dict_ = attributes.instance_dict(instance) isnew = state.runid != context.runid currentload = True loaded_instance = False else: if self._should_log_debug: - self._log_debug("_instance(): identity key %s not in session" % str(identitykey)) + self._log_debug("_instance(): identity key %s not in session", identitykey) if self.allow_null_pks: for x in identitykey[1]: @@ -1625,8 +1675,10 @@ class Mapper(object): instance = self.class_manager.new_instance() if self._should_log_debug: - self._log_debug("_instance(): created new instance %s identity %s" % (instance_str(instance), str(identitykey))) + self._log_debug("_instance(): created new instance %s identity %s", + instance_str(instance), identitykey) + dict_ = attributes.instance_dict(instance) state = attributes.instance_state(instance) state.key = identitykey @@ -1638,12 +1690,12 @@ class Mapper(object): if currentload or populate_existing: if isnew: state.runid = context.runid - context.progress.add(state) + context.progress[state] = dict_ if not populate_instance or \ populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: - populate_state(state, row, isnew, only_load_props) + populate_state(state, dict_, row, isnew, only_load_props) else: # populate attributes on non-loading instances which have been expired @@ -1652,16 +1704,16 @@ class Mapper(object): if state in context.partials: isnew = False - attrs = context.partials[state] + (d_, attrs) = context.partials[state] else: isnew = True attrs = state.unloaded - context.partials[state] = attrs #<-- allow query.instances to commit the subset of attrs + context.partials[state] = (dict_, attrs) #<-- allow query.instances to commit the subset of attrs if not populate_instance or \ populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: - populate_state(state, row, isnew, attrs, instancekey=identitykey) + populate_state(state, dict_, row, isnew, attrs, instancekey=identitykey) if loaded_instance: state._run_on_load(instance) @@ -1759,6 +1811,14 @@ def _event_on_init_failure(state, instance, args, kwargs): instrumenting_mapper, instrumenting_mapper.class_, state.manager.events.original_init, instance, args, kwargs) +def _event_on_resurrect(state, instance): + # re-populate the primary key elements + # of the dict based on the mapping. + instrumenting_mapper = state.manager.info[_INSTRUMENTOR] + for col, val in zip(instrumenting_mapper.primary_key, state.key[1]): + instrumenting_mapper._set_state_attr_by_column(state, col, val) + + def _sort_states(states): return sorted(states, key=operator.attrgetter('sort_key')) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 398cbe5d98..5605cdcd1e 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -96,13 +96,13 @@ class ColumnProperty(StrategizedProperty): return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns) def getattr(self, state, column): - return state.get_impl(self.key).get(state) + return state.get_impl(self.key).get(state, state.dict) def getcommitted(self, state, column, passive=False): - return state.get_impl(self.key).get_committed_value(state, passive=passive) + return state.get_impl(self.key).get_committed_value(state, state.dict, passive=passive) def setattr(self, state, value, column): - state.get_impl(self.key).set(state, value, None) + state.get_impl(self.key).set(state, state.dict, value, None) def merge(self, session, source, dest, dont_load, _recursive): value = attributes.instance_state(source).value_as_iterable( @@ -159,7 +159,7 @@ class CompositeProperty(ColumnProperty): super(ColumnProperty, self).do_init() def getattr(self, state, column): - obj = state.get_impl(self.key).get(state) + obj = state.get_impl(self.key).get(state, state.dict) return self.get_col_value(column, obj) def getcommitted(self, state, column, passive=False): @@ -168,7 +168,7 @@ class CompositeProperty(ColumnProperty): def setattr(self, state, value, column): - obj = state.get_impl(self.key).get(state) + obj = state.get_impl(self.key).get(state, state.dict) if obj is None: obj = self.composite_class(*[None for c in self.columns]) state.get_impl(self.key).set(state, obj, None) @@ -635,7 +635,7 @@ class RelationProperty(StrategizedProperty): return source_state = attributes.instance_state(source) - dest_state = attributes.instance_state(dest) + dest_state, dest_dict = attributes.instance_state(dest), attributes.instance_dict(dest) if not "merge" in self.cascade: dest_state.expire_attributes([self.key]) @@ -658,7 +658,7 @@ class RelationProperty(StrategizedProperty): for c in dest_list: coll.append_without_event(c) else: - getattr(dest.__class__, self.key).impl._set_iterable(dest_state, dest_list) + getattr(dest.__class__, self.key).impl._set_iterable(dest_state, dest_dict, dest_list) else: current = instances[0] if current is not None: @@ -839,8 +839,8 @@ class RelationProperty(StrategizedProperty): if self._foreign_keys: raise sa_exc.ArgumentError("Could not determine relation direction for " "primaryjoin condition '%s', on relation %s. " - "Are the columns in 'foreign_keys' present within the given " - "join condition ?" % (self.primaryjoin, self)) + "Do the columns in 'foreign_keys' represent only the 'foreign' columns " + "in this join condition ?" % (self.primaryjoin, self)) else: raise sa_exc.ArgumentError("Could not determine relation direction for " "primaryjoin condition '%s', on relation %s. " @@ -1119,6 +1119,10 @@ class RelationProperty(StrategizedProperty): if not self.viewonly: self._dependency_processor.register_dependencies(uowcommit) + def register_processors(self, uowcommit): + if not self.viewonly: + self._dependency_processor.register_processors(uowcommit) + PropertyLoader = RelationProperty log.class_logger(RelationProperty) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 533ec9aa52..be40b08c65 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1330,7 +1330,7 @@ class Query(object): rowtuple.keys = labels.keys while True: - context.progress = set() + context.progress = {} context.partials = {} if self._yield_per: @@ -1354,13 +1354,13 @@ class Query(object): rows = filter(rows) if context.refresh_state and self._only_load_props and context.refresh_state in context.progress: - context.refresh_state.commit(self._only_load_props) - context.progress.remove(context.refresh_state) + context.refresh_state.commit(context.refresh_state.dict, self._only_load_props) + context.progress.pop(context.refresh_state) session._finalize_loaded(context.progress) - for ii, attrs in context.partials.iteritems(): - ii.commit(attrs) + for ii, (dict_, attrs) in context.partials.iteritems(): + ii.commit(dict_, attrs) for row in rows: yield row @@ -1687,14 +1687,14 @@ class Query(object): evaluated_keys = value_evaluators.keys() if issubclass(cls, target_cls) and eval_condition(obj): - state = attributes.instance_state(obj) + state, dict_ = attributes.instance_state(obj), attributes.instance_dict(obj) # only evaluate unmodified attributes to_evaluate = state.unmodified.intersection(evaluated_keys) for key in to_evaluate: - state.dict[key] = value_evaluators[key](obj) + dict_[key] = value_evaluators[key](obj) - state.commit(list(to_evaluate)) + state.commit(dict_, list(to_evaluate)) # expire attributes with pending changes (there was no autoflush, so they are overwritten) state.expire_attributes(set(evaluated_keys).difference(to_evaluate)) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 1e3a750d95..cbfb0c1d64 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -12,7 +12,7 @@ import sqlalchemy.exceptions as sa_exc from sqlalchemy import util, sql, engine, log from sqlalchemy.sql import util as sql_util, expression from sqlalchemy.orm import ( - SessionExtension, attributes, exc, query, unitofwork, util as mapperutil, + SessionExtension, attributes, exc, query, unitofwork, util as mapperutil, state ) from sqlalchemy.orm.util import object_mapper as _object_mapper from sqlalchemy.orm.util import class_mapper as _class_mapper @@ -299,14 +299,14 @@ class SessionTransaction(object): self.session._expunge_state(s) for s in self.session.identity_map.all_states(): - _expire_state(s, None) + _expire_state(s, None, instance_dict=self.session.identity_map) def _remove_snapshot(self): assert self._is_transaction_boundary if not self.nested and self.session.expire_on_commit: for s in self.session.identity_map.all_states(): - _expire_state(s, None) + _expire_state(s, None, instance_dict=self.session.identity_map) def _connection_for_bind(self, bind): self._assert_is_active() @@ -899,8 +899,8 @@ class Session(object): self.flush() def _finalize_loaded(self, states): - for state in states: - state.commit_all() + for state, dict_ in states.items(): + state.commit_all(dict_, self.identity_map) def refresh(self, instance, attribute_names=None): """Refresh the attributes on the given instance. @@ -935,7 +935,7 @@ class Session(object): """Expires all persistent instances within this Session.""" for state in self.identity_map.all_states(): - _expire_state(state, None) + _expire_state(state, None, instance_dict=self.identity_map) def expire(self, instance, attribute_names=None): """Expire the attributes on an instance. @@ -956,14 +956,14 @@ class Session(object): raise exc.UnmappedInstanceError(instance) self._validate_persistent(state) if attribute_names: - _expire_state(state, attribute_names=attribute_names) + _expire_state(state, attribute_names=attribute_names, instance_dict=self.identity_map) else: # pre-fetch the full cascade since the expire is going to # remove associations cascaded = list(_cascade_state_iterator('refresh-expire', state)) - _expire_state(state, None) + _expire_state(state, None, instance_dict=self.identity_map) for (state, m, o) in cascaded: - _expire_state(state, None) + _expire_state(state, None, instance_dict=self.identity_map) def prune(self): """Remove unreferenced instances cached in the identity map. @@ -1020,12 +1020,10 @@ class Session(object): # primary key switch self.identity_map.remove(state) state.key = instance_key - - if state.key in self.identity_map and not self.identity_map.contains_state(state): - self.identity_map.remove_key(state.key) - self.identity_map.add(state) - state.commit_all() - + + self.identity_map.replace(state) + state.commit_all(state.dict, self.identity_map) + # remove from new last, might be the last strong ref if state in self._new: if self._enable_transaction_accounting and self.transaction: @@ -1213,7 +1211,7 @@ class Session(object): prop.merge(self, instance, merged, dont_load, _recursive) if dont_load: - attributes.instance_state(merged).commit_all() # remove any history + attributes.instance_state(merged).commit_all(attributes.instance_dict(merged), self.identity_map) # remove any history if new_instance: merged_state._run_on_load(merged) @@ -1362,13 +1360,12 @@ class Session(object): not self._deleted and not self._new): return - dirty = self._dirty_states if not dirty and not self._deleted and not self._new: - self.identity_map.modified = False + self.identity_map._modified.clear() return - flush_context = UOWTransaction(self) + flush_context = UOWTransaction(self) if self.extensions: for ext in self.extensions: @@ -1391,15 +1388,19 @@ class Session(object): raise exc.UnmappedInstanceError(o) objset.add(state) else: - # or just everything - objset = set(self.identity_map.all_states()).union(new) + objset = None # store objects whose fate has been decided processed = set() # put all saves/updates into the flush context. detect top-level # orphans and throw them into deleted. - for state in new.union(dirty).intersection(objset).difference(deleted): + if objset: + proc = new.union(dirty).intersection(objset).difference(deleted) + else: + proc = new.union(dirty).difference(deleted) + + for state in proc: is_orphan = _state_mapper(state)._is_orphan(state) if is_orphan and not _state_has_identity(state): path = ", nor ".join( @@ -1415,7 +1416,11 @@ class Session(object): processed.add(state) # put all remaining deletes into the flush context. - for state in deleted.intersection(objset).difference(processed): + if objset: + proc = deleted.intersection(objset).difference(processed) + else: + proc = deleted.difference(processed) + for state in proc: flush_context.register_object(state, isdelete=True) if len(flush_context.tasks) == 0: @@ -1435,9 +1440,13 @@ class Session(object): flush_context.finalize_flush_changes() - if not objects: - self.identity_map.modified = False - + # useful assertions: + #if not objects: + # assert not self.identity_map._modified + #else: + # assert self.identity_map._modified == self.identity_map._modified.difference(objects) + #self.identity_map._modified.clear() + for ext in self.extensions: ext.after_flush_postexec(self, flush_context) @@ -1486,10 +1495,7 @@ class Session(object): those that were possibly deleted. """ - return util.IdentitySet( - [state - for state in self.identity_map.all_states() - if state.check_modified()]) + return self.identity_map._dirty_states() @property def dirty(self): @@ -1528,7 +1534,7 @@ class Session(object): return util.IdentitySet(self._new.values()) -_expire_state = attributes.InstanceState.expire_attributes +_expire_state = state.InstanceState.expire_attributes UOWEventHandler = unitofwork.UOWEventHandler @@ -1548,16 +1554,19 @@ def _cascade_unknown_state_iterator(cascade, state, **kwargs): yield _state_for_unknown_persistence_instance(o), m def _state_for_unsaved_instance(instance, create=False): - manager = attributes.manager_of_class(instance.__class__) - if manager is None: + try: + state = attributes.instance_state(instance) + except AttributeError: raise exc.UnmappedInstanceError(instance) - if manager.has_state(instance): - state = manager.state_of(instance) + if state: if state.key is not None: raise sa_exc.InvalidRequestError( "Instance '%s' is already persistent" % mapperutil.state_str(state)) elif create: + manager = attributes.manager_of_class(instance.__class__) + if manager is None: + raise exc.UnmappedInstanceError(instance) state = manager.setup_instance(instance) else: raise exc.UnmappedInstanceError(instance) diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py new file mode 100644 index 0000000000..1b73a1bb62 --- /dev/null +++ b/lib/sqlalchemy/orm/state.py @@ -0,0 +1,441 @@ +from sqlalchemy.util import EMPTY_SET +import weakref +from sqlalchemy import util +from sqlalchemy.orm.attributes import PASSIVE_NORESULT, PASSIVE_OFF, NEVER_SET, NO_VALUE, manager_of_class, ATTR_WAS_SET +from sqlalchemy.orm import attributes +from sqlalchemy.orm import interfaces + +class InstanceState(object): + """tracks state information at the instance level.""" + + session_id = None + key = None + runid = None + expired_attributes = EMPTY_SET + load_options = EMPTY_SET + load_path = () + insert_order = None + mutable_dict = None + + def __init__(self, obj, manager): + self.class_ = obj.__class__ + self.manager = manager + self.obj = weakref.ref(obj, self._cleanup) + self.modified = False + self.callables = {} + self.expired = False + self.committed_state = {} + self.pending = {} + self.parents = {} + + def detach(self): + if self.session_id: + del self.session_id + + def dispose(self): + if self.session_id: + del self.session_id + del self.obj + + def _cleanup(self, ref): + instance_dict = self._instance_dict() + if instance_dict: + instance_dict.remove(self) + self.dispose() + + def obj(self): + return None + + @property + def dict(self): + o = self.obj() + if o is not None: + return attributes.instance_dict(o) + else: + return {} + + @property + def sort_key(self): + return self.key and self.key[1] or (self.insert_order, ) + + def check_modified(self): + # TODO: deprecate + return self.modified + + def initialize_instance(*mixed, **kwargs): + self, instance, args = mixed[0], mixed[1], mixed[2:] + manager = self.manager + + for fn in manager.events.on_init: + fn(self, instance, args, kwargs) + + # LESSTHANIDEAL: + # adjust for the case where the InstanceState was created before + # mapper compilation, and this actually needs to be a MutableAttrInstanceState + if manager.mutable_attributes and self.__class__ is not MutableAttrInstanceState: + self.__class__ = MutableAttrInstanceState + self.obj = weakref.ref(self.obj(), self._cleanup) + self.mutable_dict = {} + + try: + return manager.events.original_init(*mixed[1:], **kwargs) + except: + for fn in manager.events.on_init_failure: + fn(self, instance, args, kwargs) + raise + + def get_history(self, key, **kwargs): + return self.manager.get_impl(key).get_history(self, self.dict, **kwargs) + + def get_impl(self, key): + return self.manager.get_impl(key) + + def get_pending(self, key): + if key not in self.pending: + self.pending[key] = PendingCollection() + return self.pending[key] + + def value_as_iterable(self, key, passive=PASSIVE_OFF): + """return an InstanceState attribute as a list, + regardless of it being a scalar or collection-based + attribute. + + returns None if passive is not PASSIVE_OFF and the getter returns + PASSIVE_NORESULT. + """ + + impl = self.get_impl(key) + dict_ = self.dict + x = impl.get(self, dict_, passive=passive) + if x is PASSIVE_NORESULT: + return None + elif hasattr(impl, 'get_collection'): + return impl.get_collection(self, dict_, x, passive=passive) + elif isinstance(x, list): + return x + else: + return [x] + + def _run_on_load(self, instance): + self.manager.events.run('on_load', instance) + + def __getstate__(self): + return {'key': self.key, + 'committed_state': self.committed_state, + 'pending': self.pending, + 'parents': self.parents, + 'modified': self.modified, + 'expired':self.expired, + 'load_options':self.load_options, + 'load_path':interfaces.serialize_path(self.load_path), + 'instance': self.obj(), + 'expired_attributes':self.expired_attributes, + 'callables': self.callables} + + def __setstate__(self, state): + self.committed_state = state['committed_state'] + self.parents = state['parents'] + self.key = state['key'] + self.session_id = None + self.pending = state['pending'] + self.modified = state['modified'] + self.obj = weakref.ref(state['instance']) + self.load_options = state['load_options'] or EMPTY_SET + self.load_path = interfaces.deserialize_path(state['load_path']) + self.class_ = self.obj().__class__ + self.manager = manager_of_class(self.class_) + self.callables = state['callables'] + self.runid = None + self.expired = state['expired'] + self.expired_attributes = state['expired_attributes'] + + def initialize(self, key): + self.manager.get_impl(key).initialize(self, self.dict) + + def set_callable(self, key, callable_): + self.dict.pop(key, None) + self.callables[key] = callable_ + + def __call__(self): + """__call__ allows the InstanceState to act as a deferred + callable for loading expired attributes, which is also + serializable (picklable). + + """ + unmodified = self.unmodified + class_manager = self.manager + class_manager.deferred_scalar_loader(self, [ + attr.impl.key for attr in class_manager.attributes if + attr.impl.accepts_scalar_loader and + attr.impl.key in self.expired_attributes and + attr.impl.key in unmodified + ]) + for k in self.expired_attributes: + self.callables.pop(k, None) + del self.expired_attributes + return ATTR_WAS_SET + + @property + def unmodified(self): + """a set of keys which have no uncommitted changes""" + + return set(self.manager).difference(self.committed_state) + + @property + def unloaded(self): + """a set of keys which do not have a loaded value. + + This includes expired attributes and any other attribute that + was never populated or modified. + + """ + return set( + key for key in self.manager.iterkeys() + if key not in self.committed_state and key not in self.dict) + + def expire_attributes(self, attribute_names, instance_dict=None): + self.expired_attributes = set(self.expired_attributes) + + if attribute_names is None: + attribute_names = self.manager.keys() + self.expired = True + if self.modified: + if not instance_dict: + instance_dict = self._instance_dict() + if instance_dict: + instance_dict._modified.discard(self) + else: + instance_dict._modified.discard(self) + + self.modified = False + filter_deferred = True + else: + filter_deferred = False + dict_ = self.dict + + for key in attribute_names: + impl = self.manager[key].impl + if not filter_deferred or \ + not impl.dont_expire_missing or \ + key in dict_: + self.expired_attributes.add(key) + if impl.accepts_scalar_loader: + self.callables[key] = self + dict_.pop(key, None) + self.pending.pop(key, None) + self.committed_state.pop(key, None) + if self.mutable_dict: + self.mutable_dict.pop(key, None) + + def reset(self, key, dict_): + """remove the given attribute and any callables associated with it.""" + + dict_.pop(key, None) + self.callables.pop(key, None) + + def _instance_dict(self): + return None + + def _is_really_none(self): + return self.obj() + + def modified_event(self, dict_, attr, should_copy, previous, passive=PASSIVE_OFF): + needs_committed = attr.key not in self.committed_state + + if needs_committed: + if previous is NEVER_SET: + if passive: + if attr.key in dict_: + previous = dict_[attr.key] + else: + previous = attr.get(self, dict_) + + if should_copy and previous not in (None, NO_VALUE, NEVER_SET): + previous = attr.copy(previous) + + if needs_committed: + self.committed_state[attr.key] = previous + + if not self.modified: + instance_dict = self._instance_dict() + if instance_dict: + instance_dict._modified.add(self) + + self.modified = True + self._strong_obj = self.obj() + + def commit(self, dict_, keys): + """Commit attributes. + + This is used by a partial-attribute load operation to mark committed + those attributes which were refreshed from the database. + + Attributes marked as "expired" can potentially remain "expired" after + this step if a value was not populated in state.dict. + + """ + class_manager = self.manager + for key in keys: + if key in dict_ and key in class_manager.mutable_attributes: + class_manager[key].impl.commit_to_state(self, dict_, self.committed_state) + else: + self.committed_state.pop(key, None) + + self.expired = False + # unexpire attributes which have loaded + for key in self.expired_attributes.intersection(keys): + if key in dict_: + self.expired_attributes.remove(key) + self.callables.pop(key, None) + + def commit_all(self, dict_, instance_dict=None): + """commit all attributes unconditionally. + + This is used after a flush() or a full load/refresh + to remove all pending state from the instance. + + - all attributes are marked as "committed" + - the "strong dirty reference" is removed + - the "modified" flag is set to False + - any "expired" markers/callables are removed. + + Attributes marked as "expired" can potentially remain "expired" after this step + if a value was not populated in state.dict. + + """ + + self.committed_state = {} + self.pending = {} + + # unexpire attributes which have loaded + if self.expired_attributes: + for key in self.expired_attributes.intersection(dict_): + self.callables.pop(key, None) + self.expired_attributes.difference_update(dict_) + + for key in self.manager.mutable_attributes: + if key in dict_: + self.manager[key].impl.commit_to_state(self, dict_, self.committed_state) + + if instance_dict and self.modified: + instance_dict._modified.discard(self) + + self.modified = self.expired = False + self._strong_obj = None + +class MutableAttrInstanceState(InstanceState): + def __init__(self, obj, manager): + self.mutable_dict = {} + InstanceState.__init__(self, obj, manager) + + def _get_modified(self, dict_=None): + if self.__dict__.get('modified', False): + return True + else: + if dict_ is None: + dict_ = self.dict + for key in self.manager.mutable_attributes: + if self.manager[key].impl.check_mutable_modified(self, dict_): + return True + else: + return False + + def _set_modified(self, value): + self.__dict__['modified'] = value + + modified = property(_get_modified, _set_modified) + + @property + def unmodified(self): + """a set of keys which have no uncommitted changes""" + + dict_ = self.dict + return set( + key for key in self.manager.iterkeys() + if (key not in self.committed_state or + (key in self.manager.mutable_attributes and + not self.manager[key].impl.check_mutable_modified(self, dict_)))) + + def _is_really_none(self): + """do a check modified/resurrect. + + This would be called in the extremely rare + race condition that the weakref returned None but + the cleanup handler had not yet established the + __resurrect callable as its replacement. + + """ + if self.modified: + self.obj = self.__resurrect + return self.obj() + else: + return None + + def reset(self, key, dict_): + self.mutable_dict.pop(key, None) + InstanceState.reset(self, key, dict_) + + def _cleanup(self, ref): + """weakref callback. + + This method may be called by an asynchronous + gc. + + If the state shows pending changes, the weakref + is replaced by the __resurrect callable which will + re-establish an object reference on next access, + else removes this InstanceState from the owning + identity map, if any. + + """ + if self._get_modified(self.mutable_dict): + self.obj = self.__resurrect + else: + instance_dict = self._instance_dict() + if instance_dict: + instance_dict.remove(self) + self.dispose() + + def __resurrect(self): + """A substitute for the obj() weakref function which resurrects.""" + + # store strong ref'ed version of the object; will revert + # to weakref when changes are persisted + + obj = self.manager.new_instance(state=self) + self.obj = weakref.ref(obj, self._cleanup) + self._strong_obj = obj + obj.__dict__.update(self.mutable_dict) + + # re-establishes identity attributes from the key + self.manager.events.run('on_resurrect', self, obj) + + # TODO: don't really think we should run this here. + # resurrect is only meant to preserve the minimal state needed to + # do an UPDATE, not to produce a fully usable object + self._run_on_load(obj) + + return obj + +class PendingCollection(object): + """A writable placeholder for an unloaded collection. + + Stores items appended to and removed from a collection that has not yet + been loaded. When the collection is loaded, the changes stored in + PendingCollection are applied to it to produce the final result. + + """ + def __init__(self): + self.deleted_items = util.IdentitySet() + self.added_items = util.OrderedIdentitySet() + + def append(self, value): + if value in self.deleted_items: + self.deleted_items.remove(value) + self.added_items.add(value) + + def remove(self, value): + if value in self.added_items: + self.added_items.remove(value) + self.deleted_items.add(value) + diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 1aeb311e1c..20cbb8f4dc 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -115,8 +115,8 @@ class ColumnLoader(LoaderStrategy): if adapter: col = adapter.columns[col] if col in row: - def new_execute(state, row, **flags): - state.dict[key] = row[col] + def new_execute(state, dict_, row, **flags): + dict_[key] = row[col] if self._should_log_debug: new_execute = self.debug_callable(new_execute, self.logger, @@ -125,7 +125,7 @@ class ColumnLoader(LoaderStrategy): ) return (new_execute, None) else: - def new_execute(state, row, isnew, **flags): + def new_execute(state, dict_, row, isnew, **flags): if isnew: state.expire_attributes([key]) if self._should_log_debug: @@ -171,15 +171,15 @@ class CompositeColumnLoader(ColumnLoader): columns = [adapter.columns[c] for c in columns] for c in columns: if c not in row: - def new_execute(state, row, isnew, **flags): + def new_execute(state, dict_, row, isnew, **flags): if isnew: state.expire_attributes([key]) if self._should_log_debug: self.logger.debug("%s deferring load" % self) return (new_execute, None) else: - def new_execute(state, row, **flags): - state.dict[key] = composite_class(*[row[c] for c in columns]) + def new_execute(state, dict_, row, **flags): + dict_[key] = composite_class(*[row[c] for c in columns]) if self._should_log_debug: new_execute = self.debug_callable(new_execute, self.logger, @@ -202,13 +202,13 @@ class DeferredColumnLoader(LoaderStrategy): return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, path, mapper, row, adapter) elif not self.is_class_level: - def new_execute(state, row, **flags): + def new_execute(state, dict_, row, **flags): state.set_callable(self.key, LoadDeferredColumns(state, self.key)) else: - def new_execute(state, row, **flags): + def new_execute(state, dict_, row, **flags): # reset state on the key so that deferred callables # fire off on next access. - state.reset(self.key) + state.reset(self.key, dict_) if self._should_log_debug: new_execute = self.debug_callable(new_execute, self.logger, None, @@ -340,7 +340,7 @@ class NoLoader(AbstractRelationLoader): ) def create_row_processor(self, selectcontext, path, mapper, row, adapter): - def new_execute(state, row, **flags): + def new_execute(state, dict_, row, **flags): self._init_instance_attribute(state) if self._should_log_debug: @@ -437,7 +437,7 @@ class LazyLoader(AbstractRelationLoader): def create_row_processor(self, selectcontext, path, mapper, row, adapter): if not self.is_class_level: - def new_execute(state, row, **flags): + def new_execute(state, dict_, row, **flags): # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader, # which will override the class-level behavior. # this currently only happens when using a "lazyload" option on a "no load" attribute - @@ -451,11 +451,11 @@ class LazyLoader(AbstractRelationLoader): return (new_execute, None) else: - def new_execute(state, row, **flags): + def new_execute(state, dict_, row, **flags): # we are the primary manager for this attribute on this class - reset its per-instance attribute state, # so that the class-level lazy loader is executed when next referenced on this instance. # this is needed in populate_existing() types of scenarios to reset any existing state. - state.reset(self.key) + state.reset(self.key, dict_) if self._should_log_debug: new_execute = self.debug_callable(new_execute, self.logger, None, @@ -735,24 +735,24 @@ class EagerLoader(AbstractRelationLoader): _instance = self.mapper._instance_processor(context, path + (self.mapper.base_mapper,), eager_adapter) if not self.uselist: - def execute(state, row, isnew, **flags): + def execute(state, dict_, row, isnew, **flags): if isnew: # set a scalar object instance directly on the # parent object, bypassing InstrumentedAttribute # event handlers. - state.dict[key] = _instance(row, None) + dict_[key] = _instance(row, None) else: # call _instance on the row, even though the object has been created, # so that we further descend into properties _instance(row, None) else: - def execute(state, row, isnew, **flags): + def execute(state, dict_, row, isnew, **flags): if isnew or (state, key) not in context.attributes: # appender_key can be absent from context.attributes with isnew=False # when self-referential eager loading is used; the same instance may be present # in two distinct sets of result columns - collection = attributes.init_state_collection(state, key) + collection = attributes.init_state_collection(state, dict_, key) appender = util.UniqueAppender(collection, 'append_without_event') context.attributes[(state, key)] = appender diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index dd979e1a80..c12f17aff5 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -50,26 +50,18 @@ def populate_dict(source, source_mapper, dict_, synchronize_pairs): dict_[r.key] = value -def source_changes(uowcommit, source, source_mapper, synchronize_pairs): +def source_modified(uowcommit, source, source_mapper, synchronize_pairs): + """return true if the source object has changes from an old to a new value on the given + synchronize pairs + + """ for l, r in synchronize_pairs: try: prop = source_mapper._get_col_to_prop(l) except exc.UnmappedColumnError: _raise_col_to_prop(False, source_mapper, l, None, r) history = uowcommit.get_attribute_history(source, prop.key, passive=True) - if history.has_changes(): - return True - else: - return False - -def dest_changes(uowcommit, dest, dest_mapper, synchronize_pairs): - for l, r in synchronize_pairs: - try: - prop = dest_mapper._get_col_to_prop(r) - except exc.UnmappedColumnError: - _raise_col_to_prop(True, None, l, dest_mapper, r) - history = uowcommit.get_attribute_history(dest, prop.key, passive=True) - if history.has_changes(): + if len(history.deleted): return True else: return False diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 4ac9c765e0..da26c8d7b3 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -96,6 +96,8 @@ class UOWTransaction(object): # information. self.attributes = {} + self.processors = set() + def get_attribute_history(self, state, key, passive=True): hashkey = ("history", state, key) @@ -119,6 +121,7 @@ class UOWTransaction(object): return history.as_state() def register_object(self, state, isdelete=False, listonly=False, postupdate=False, post_update_cols=None): + # if object is not in the overall session, do nothing if not self.session._contains_state(state): if self._should_log_debug: @@ -136,6 +139,16 @@ class UOWTransaction(object): else: task.append(state, listonly=listonly, isdelete=isdelete) + # ensure the mapper for this object has had its + # DependencyProcessors added. + if mapper not in self.processors: + mapper._register_processors(self) + self.processors.add(mapper) + + if mapper.base_mapper not in self.processors: + mapper.base_mapper._register_processors(self) + self.processors.add(mapper.base_mapper) + def set_row_switch(self, state): """mark a deleted object as a 'row switch'. @@ -147,7 +160,7 @@ class UOWTransaction(object): task = self.get_task_by_mapper(mapper) taskelement = task._objects[state] taskelement.isdelete = "rowswitch" - + def is_deleted(self, state): """return true if the given state is marked as deleted within this UOWTransaction.""" @@ -201,9 +214,9 @@ class UOWTransaction(object): self.dependencies.add((mapper, dependency)) def register_processor(self, mapper, processor, mapperfrom): - """register a dependency processor, corresponding to dependencies between - the two given mappers. - + """register a dependency processor, corresponding to + operations which occur between two mappers. + """ # correct for primary mapper mapper = mapper.primary_mapper() diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 4ecc7a0678..3fd95642e6 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1608,7 +1608,7 @@ class ColumnElement(ClauseElement, _CompareMixin): def shares_lineage(self, othercolumn): """Return True if the given ``ColumnElement`` has a common ancestor to this ``ColumnElement``.""" - return len(self.proxy_set.intersection(othercolumn.proxy_set)) > 0 + return bool(self.proxy_set.intersection(othercolumn.proxy_set)) def _make_proxy(self, selectable, name=None): """Create a new ``ColumnElement`` representing this diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 36357faf50..f1f329b5e2 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -343,14 +343,14 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_re return if consider_as_foreign_keys: - if binary.left in consider_as_foreign_keys: + if binary.left in consider_as_foreign_keys and (binary.right is binary.left or binary.right not in consider_as_foreign_keys): pairs.append((binary.right, binary.left)) - elif binary.right in consider_as_foreign_keys: + elif binary.right in consider_as_foreign_keys and (binary.left is binary.right or binary.left not in consider_as_foreign_keys): pairs.append((binary.left, binary.right)) elif consider_as_referenced_keys: - if binary.left in consider_as_referenced_keys: + if binary.left in consider_as_referenced_keys and (binary.right is binary.left or binary.right not in consider_as_referenced_keys): pairs.append((binary.left, binary.right)) - elif binary.right in consider_as_referenced_keys: + elif binary.right in consider_as_referenced_keys and (binary.left is binary.right or binary.left not in consider_as_referenced_keys): pairs.append((binary.right, binary.left)) else: if isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column): diff --git a/test/dialect/sqlite.py b/test/dialect/sqlite.py index f114619497..23c0389550 100644 --- a/test/dialect/sqlite.py +++ b/test/dialect/sqlite.py @@ -11,6 +11,28 @@ from testlib import * class TestTypes(TestBase, AssertsExecutionResults): __only_on__ = 'sqlite' + def test_boolean(self): + """Test that the boolean only treats 1 as True + + """ + + meta = MetaData(testing.db) + t = Table('bool_table', meta, + Column('id', Integer, primary_key=True), + Column('boo', sqlite.SLBoolean)) + + try: + meta.create_all() + testing.db.execute("INSERT INTO bool_table (id, boo) VALUES (1, 'false');") + testing.db.execute("INSERT INTO bool_table (id, boo) VALUES (2, 'true');") + testing.db.execute("INSERT INTO bool_table (id, boo) VALUES (3, '1');") + testing.db.execute("INSERT INTO bool_table (id, boo) VALUES (4, '0');") + testing.db.execute("INSERT INTO bool_table (id, boo) VALUES (5, 1);") + testing.db.execute("INSERT INTO bool_table (id, boo) VALUES (6, 0);") + assert t.select(t.c.boo).order_by(t.c.id).execute().fetchall() == [(3, True,), (5, True,)] + finally: + meta.drop_all() + def test_string_dates_raise(self): self.assertRaises(TypeError, testing.db.execute, select([1]).where(bindparam("date", type_=Date)), date=str(datetime.date(2007, 10, 30))) diff --git a/test/orm/attributes.py b/test/orm/attributes.py index 46d944cbc3..0f15d5136f 100644 --- a/test/orm/attributes.py +++ b/test/orm/attributes.py @@ -38,7 +38,7 @@ class AttributesTest(_base.ORMTest): u.email_address = 'lala@123.com' self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') - attributes.instance_state(u).commit_all() + attributes.instance_state(u).commit_all(attributes.instance_dict(u)) self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') u.user_name = 'heythere' @@ -158,7 +158,7 @@ class AttributesTest(_base.ORMTest): eq_(f.a, None) eq_(f.b, 12) - attributes.instance_state(f).commit_all() + attributes.instance_state(f).commit_all(attributes.instance_dict(f)) eq_(f.a, None) eq_(f.b, 12) @@ -205,7 +205,7 @@ class AttributesTest(_base.ORMTest): u.addresses.append(a) self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') - u, attributes.instance_state(a).commit_all() + u, attributes.instance_state(a).commit_all(attributes.instance_dict(a)) self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') u.user_name = 'heythere' @@ -272,7 +272,7 @@ class AttributesTest(_base.ORMTest): p1 = Post() attributes.instance_state(b).set_callable('posts', lambda:[p1]) attributes.instance_state(p1).set_callable('blog', lambda:b) - p1, attributes.instance_state(b).commit_all() + p1, attributes.instance_state(b).commit_all(attributes.instance_dict(b)) # no orphans (called before the lazy loaders fire off) assert attributes.has_parent(Blog, p1, 'posts', optimistic=True) @@ -353,7 +353,7 @@ class AttributesTest(_base.ORMTest): x = Bar() x.element = el eq_(attributes.get_history(attributes.instance_state(x), 'element'), ([el], (), ())) - attributes.instance_state(x).commit_all() + attributes.instance_state(x).commit_all(attributes.instance_dict(x)) (added, unchanged, deleted) = attributes.get_history(attributes.instance_state(x), 'element') assert added == () @@ -381,7 +381,7 @@ class AttributesTest(_base.ORMTest): attributes.register_attribute(Bar, 'id', uselist=False, useobject=True) x = Foo() - attributes.instance_state(x).commit_all() + attributes.instance_state(x).commit_all(attributes.instance_dict(x)) x.col2.append(bar4) eq_(attributes.get_history(attributes.instance_state(x), 'col2'), ([bar4], [bar1, bar2, bar3], [])) @@ -427,7 +427,7 @@ class AttributesTest(_base.ORMTest): attributes.register_attribute(Foo, 'element', uselist=False, copy_function=lambda x:[y for y in x], mutable_scalars=True, useobject=False) x = Foo() x.element = ['one', 'two', 'three'] - attributes.instance_state(x).commit_all() + attributes.instance_state(x).commit_all(attributes.instance_dict(x)) x.element[1] = 'five' assert attributes.instance_state(x).check_modified() @@ -437,7 +437,7 @@ class AttributesTest(_base.ORMTest): attributes.register_attribute(Foo, 'element', uselist=False, useobject=False) x = Foo() x.element = ['one', 'two', 'three'] - attributes.instance_state(x).commit_all() + attributes.instance_state(x).commit_all(attributes.instance_dict(x)) x.element[1] = 'five' assert not attributes.instance_state(x).check_modified() @@ -699,8 +699,8 @@ class PendingBackrefTest(_base.ORMTest): b = Blog("blog 1") p1.blog = b - attributes.instance_state(b).commit_all() - attributes.instance_state(p1).commit_all() + attributes.instance_state(b).commit_all(attributes.instance_dict(b)) + attributes.instance_state(p1).commit_all(attributes.instance_dict(p1)) assert b.posts == [Post("post 1")] class HistoryTest(_base.ORMTest): @@ -713,17 +713,17 @@ class HistoryTest(_base.ORMTest): attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=False) f = Foo() - eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None) + eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), None) f.someattr = 3 - eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None) + eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), None) f = Foo() f.someattr = 3 - eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None) + eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), None) - attributes.instance_state(f).commit(['someattr']) - eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), 3) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), 3) def test_scalar(self): class Foo(_base.BasicEntity): @@ -739,13 +739,13 @@ class HistoryTest(_base.ORMTest): f.someattr = "hi" eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['hi'], (), ())) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['hi'], ())) f.someattr = 'there' eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['there'], (), ['hi'])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['there'], ())) @@ -760,7 +760,7 @@ class HistoryTest(_base.ORMTest): f.someattr = 'old' eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['old'], (), ['new'])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['old'], ())) # setting None on uninitialized is currently a change for a scalar attribute @@ -778,7 +778,7 @@ class HistoryTest(_base.ORMTest): # set same value twice f = Foo() - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) f.someattr = 'one' eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['one'], (), ())) f.someattr = 'two' @@ -799,7 +799,7 @@ class HistoryTest(_base.ORMTest): f.someattr = {'foo':'hi'} eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'hi'}], (), ())) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'hi'}], ())) eq_(attributes.instance_state(f).committed_state['someattr'], {'foo':'hi'}) @@ -807,7 +807,7 @@ class HistoryTest(_base.ORMTest): eq_(attributes.instance_state(f).committed_state['someattr'], {'foo':'hi'}) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'there'}], (), [{'foo':'hi'}])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'there'}], ())) @@ -819,7 +819,7 @@ class HistoryTest(_base.ORMTest): f.someattr = {'foo':'old'} eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'old'}], (), [{'foo':'new'}])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'old'}], ())) @@ -847,13 +847,13 @@ class HistoryTest(_base.ORMTest): f.someattr = hi eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], (), ())) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi], ())) f.someattr = there eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], (), [hi])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [there], ())) @@ -868,7 +868,7 @@ class HistoryTest(_base.ORMTest): f.someattr = old eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], (), ['new'])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [old], ())) # setting None on uninitialized is currently not a change for an object attribute @@ -887,7 +887,7 @@ class HistoryTest(_base.ORMTest): # set same value twice f = Foo() - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) f.someattr = 'one' eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['one'], (), ())) f.someattr = 'two' @@ -915,13 +915,13 @@ class HistoryTest(_base.ORMTest): f.someattr = [hi] eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], [])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi], ())) f.someattr = [there] eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [], [hi])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [there], ())) @@ -935,13 +935,13 @@ class HistoryTest(_base.ORMTest): f = Foo() collection = attributes.init_collection(attributes.instance_state(f), 'someattr') collection.append_without_event(new) - attributes.instance_state(f).commit_all() + attributes.instance_state(f).commit_all(attributes.instance_dict(f)) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new], ())) f.someattr = [old] eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [], [new])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [old], ())) def test_dict_collections(self): @@ -969,7 +969,7 @@ class HistoryTest(_base.ORMTest): f.someattr['there'] = there eq_(tuple([set(x) for x in attributes.get_history(attributes.instance_state(f), 'someattr')]), (set([hi, there]), set(), set())) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(tuple([set(x) for x in attributes.get_history(attributes.instance_state(f), 'someattr')]), (set(), set([hi, there]), set())) def test_object_collections_mutate(self): @@ -994,13 +994,13 @@ class HistoryTest(_base.ORMTest): f.someattr.append(hi) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], [])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi], ())) f.someattr.append(there) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [hi], [])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, there], ())) @@ -1010,7 +1010,7 @@ class HistoryTest(_base.ORMTest): f.someattr.append(old) f.someattr.append(new) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old, new], [hi], [there])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, old, new], ())) f.someattr.pop(0) @@ -1021,19 +1021,19 @@ class HistoryTest(_base.ORMTest): f.__dict__['id'] = 1 collection = attributes.init_collection(attributes.instance_state(f), 'someattr') collection.append_without_event(new) - attributes.instance_state(f).commit_all() + attributes.instance_state(f).commit_all(attributes.instance_dict(f)) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new], ())) f.someattr.append(old) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [new], [])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new, old], ())) f = Foo() collection = attributes.init_collection(attributes.instance_state(f), 'someattr') collection.append_without_event(new) - attributes.instance_state(f).commit_all() + attributes.instance_state(f).commit_all(attributes.instance_dict(f)) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new], ())) f.id = 1 @@ -1056,7 +1056,7 @@ class HistoryTest(_base.ORMTest): f.someattr.append(hi) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi, there, hi], [], [])) - attributes.instance_state(f).commit_all() + attributes.instance_state(f).commit_all(attributes.instance_dict(f)) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, there, hi], ())) f.someattr = [] diff --git a/test/orm/extendedattr.py b/test/orm/extendedattr.py index 69164ebafb..aec6c181f2 100644 --- a/test/orm/extendedattr.py +++ b/test/orm/extendedattr.py @@ -117,7 +117,7 @@ class UserDefinedExtensionTest(_base.ORMTest): u.user_id = 7 u.user_name = 'john' u.email_address = 'lala@123.com' - self.assert_(u.__dict__ == {'_my_state':u._my_state, '_goofy_dict':{'user_id':7, 'user_name':'john', 'email_address':'lala@123.com'}}) + self.assert_(u.__dict__ == {'_my_state':u._my_state, '_goofy_dict':{'user_id':7, 'user_name':'john', 'email_address':'lala@123.com'}}, u.__dict__) def test_basic(self): for base in (object, MyBaseClass, MyClass): @@ -135,7 +135,7 @@ class UserDefinedExtensionTest(_base.ORMTest): u.email_address = 'lala@123.com' self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') - attributes.instance_state(u).commit_all() + attributes.instance_state(u).commit_all(attributes.instance_dict(u)) self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') u.user_name = 'heythere' @@ -182,7 +182,7 @@ class UserDefinedExtensionTest(_base.ORMTest): self.assertEquals(f.a, None) self.assertEquals(f.b, 12) - attributes.instance_state(f).commit_all() + attributes.instance_state(f).commit_all(attributes.instance_dict(f)) self.assertEquals(f.a, None) self.assertEquals(f.b, 12) @@ -272,8 +272,8 @@ class UserDefinedExtensionTest(_base.ORMTest): f1.bars.append(b1) self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], [])) - attributes.instance_state(f1).commit_all() - attributes.instance_state(b1).commit_all() + attributes.instance_state(f1).commit_all(attributes.instance_dict(f1)) + attributes.instance_state(b1).commit_all(attributes.instance_dict(b1)) self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), ((), ['f1'], ())) self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ((), [b1], ())) diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py index ddb4fa4ba5..d7f19a2cc0 100644 --- a/test/orm/inheritance/basic.py +++ b/test/orm/inheritance/basic.py @@ -5,6 +5,7 @@ from sqlalchemy.orm import * from sqlalchemy.orm import exc as orm_exc from testlib import * from testlib import fixtures +from orm import _base, _fixtures class O2MTest(ORMTest): """deals with inheritance and one-to-many relationships""" @@ -924,6 +925,49 @@ class OptimizedLoadTest(ORMTest): # the optimized load needs to return "None" so regular full-row loading proceeds s1 = sess.query(Base).get(s1.id) assert s1.sub == 's1sub' + +class PKDiscriminatorTest(_base.MappedTest): + def define_tables(self, metadata): + parents = Table('parents', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(60))) + + children = Table('children', metadata, + Column('id', Integer, ForeignKey('parents.id'), primary_key=True), + Column('type', Integer,primary_key=True), + Column('name', String(60))) + + @testing.resolve_artifact_names + def test_pk_as_discriminator(self): + class Parent(object): + def __init__(self, name=None): + self.name = name + + class Child(object): + def __init__(self, name=None): + self.name = name + + class A(Child): + pass + + mapper(Parent, parents, properties={ + 'children': relation(Child, backref='parent'), + }) + mapper(Child, children, polymorphic_on=children.c.type, + polymorphic_identity=1) + + mapper(A, inherits=Child, polymorphic_identity=2) + + s = create_session() + p = Parent('p1') + a = A('a1') + p.children.append(a) + s.add(p) + s.flush() + + assert a.id + assert a.type == 2 + class DeleteOrphanTest(ORMTest): def define_tables(self, metadata): diff --git a/test/orm/instrumentation.py b/test/orm/instrumentation.py index 081c46cdd8..fd15420d0a 100644 --- a/test/orm/instrumentation.py +++ b/test/orm/instrumentation.py @@ -1,8 +1,8 @@ import testenv; testenv.configure_for_tests() from testlib import sa -from testlib.sa import MetaData, Table, Column, Integer, ForeignKey -from testlib.sa.orm import mapper, relation, create_session, attributes, class_mapper +from testlib.sa import MetaData, Table, Column, Integer, ForeignKey, util +from testlib.sa.orm import mapper, relation, create_session, attributes, class_mapper, clear_mappers from testlib.testing import eq_, ne_ from testlib.compat import _function_named from orm import _base @@ -458,25 +458,9 @@ class MapperInitTest(_base.ORMTest): m = mapper(A, self.fixture()) - a = attributes.instance_state(A()) - assert isinstance(a, attributes.InstanceState) - assert type(a) is not attributes.InstanceState - - b = attributes.instance_state(B()) - assert isinstance(b, attributes.InstanceState) - assert type(b) is not attributes.InstanceState - # B is not mapped in the current implementation self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, B) - # the constructor of C is decorated too. - # we don't support unmapped subclasses in any case, - # users should not be expecting any particular behavior - # from this scenario. - c = attributes.instance_state(C(3)) - assert isinstance(c, attributes.InstanceState) - assert type(c) is not attributes.InstanceState - # C is not mapped in the current implementation self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, C) @@ -573,6 +557,10 @@ class OnLoadTest(_base.ORMTest): finally: del A + def tearDownAll(self): + clear_mappers() + attributes._install_lookup_strategy(util.symbol('native')) + class ExtendedEventsTest(_base.ORMTest): """Allow custom Events implementations.""" @@ -593,6 +581,7 @@ class ExtendedEventsTest(_base.ORMTest): assert isinstance(manager.events, MyEvents) + class NativeInstrumentationTest(_base.ORMTest): @with_lookup_strategy(sa.util.symbol('native')) def test_register_reserved_attribute(self): diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 8192b195ae..26a76301f2 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -1754,9 +1754,9 @@ class CompositeTypesTest(_base.MappedTest): return [self.x, self.y] __hash__ = None def __eq__(self, other): - return other.x == self.x and other.y == self.y + return isinstance(other, Point) and other.x == self.x and other.y == self.y def __ne__(self, other): - return not self.__eq__(other) + return not isinstance(other, Point) or not self.__eq__(other) class Graph(object): pass @@ -1822,6 +1822,12 @@ class CompositeTypesTest(_base.MappedTest): # query by columns eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, 19, 5)]) + e = g.edges[1] + e.end.x = e.end.y = None + sess.flush() + eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, None, None)]) + + @testing.resolve_artifact_names def test_pk(self): """Using a composite type as a primary key""" diff --git a/test/orm/merge.py b/test/orm/merge.py index 02f8563c18..3f832e33bb 100644 --- a/test/orm/merge.py +++ b/test/orm/merge.py @@ -221,6 +221,15 @@ class MergeTest(_fixtures.FixtureTest): Address(email_address='hoho@bar.com')])) eq_(on_load.called, 6) + @testing.resolve_artifact_names + def test_merge_empty_attributes(self): + mapper(User, dingalings) + u1 = User(id=1) + sess = create_session() + sess.merge(u1) + sess.flush() + assert u1.address_id is u1.data is None + @testing.resolve_artifact_names def test_attribute_cascade(self): """Merge of a persistent entity with two child persistent entities.""" diff --git a/test/orm/naturalpks.py b/test/orm/naturalpks.py index 980165fc0b..8efce660c3 100644 --- a/test/orm/naturalpks.py +++ b/test/orm/naturalpks.py @@ -14,20 +14,23 @@ class NaturalPKTest(_base.MappedTest): def define_tables(self, metadata): users = Table('users', metadata, Column('username', String(50), primary_key=True), - Column('fullname', String(100))) + Column('fullname', String(100)), + test_needs_fk=True) addresses = Table('addresses', metadata, Column('email', String(50), primary_key=True), - Column('username', String(50), ForeignKey('users.username', onupdate="cascade"))) + Column('username', String(50), ForeignKey('users.username', onupdate="cascade")), + test_needs_fk=True) items = Table('items', metadata, Column('itemname', String(50), primary_key=True), - Column('description', String(100))) + Column('description', String(100)), + test_needs_fk=True) users_to_items = Table('users_to_items', metadata, Column('username', String(50), ForeignKey('users.username', onupdate='cascade'), primary_key=True), Column('itemname', String(50), ForeignKey('items.itemname', onupdate='cascade'), primary_key=True), - ) + test_needs_fk=True) def setup_classes(self): class User(_base.ComparableEntity): @@ -101,8 +104,7 @@ class NaturalPKTest(_base.MappedTest): assert sess.query(User).get('ed').fullname == 'jack' - @testing.fails_on('mysql', 'FIXME: unknown') - @testing.fails_on('sqlite', 'FIXME: unknown') + @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') def test_onetomany_passive(self): self._test_onetomany(True) @@ -153,8 +155,7 @@ class NaturalPKTest(_base.MappedTest): self.assertEquals(User(username='fred', fullname='jack'), u1) - @testing.fails_on('sqlite', 'FIXME: unknown') - @testing.fails_on('mysql', 'FIXME: unknown') + @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') def test_manytoone_passive(self): self._test_manytoone(True) @@ -181,8 +182,6 @@ class NaturalPKTest(_base.MappedTest): u1.username = 'ed' - print id(a1), id(a2), id(u1) - print sa.orm.attributes.instance_state(u1).parents def go(): sess.flush() if passive_updates: @@ -198,8 +197,48 @@ class NaturalPKTest(_base.MappedTest): sess.expunge_all() self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) - @testing.fails_on('sqlite', 'FIXME: unknown') - @testing.fails_on('mysql', 'FIXME: unknown') + @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') + def test_onetoone_passive(self): + self._test_onetoone(True) + + def test_onetoone_nonpassive(self): + self._test_onetoone(False) + + @testing.resolve_artifact_names + def _test_onetoone(self, passive_updates): + mapper(User, users, properties={ + "address":relation(Address, passive_updates=passive_updates, uselist=False) + }) + mapper(Address, addresses) + + sess = create_session() + u1 = User(username='jack', fullname='jack') + sess.add(u1) + sess.flush() + + a1 = Address(email='jack1') + u1.address = a1 + sess.add(a1) + sess.flush() + + u1.username = 'ed' + + def go(): + sess.flush() + if passive_updates: + sess.expire(u1, ['address']) + self.assert_sql_count(testing.db, go, 1) + else: + self.assert_sql_count(testing.db, go, 2) + + def go(): + sess.flush() + self.assert_sql_count(testing.db, go, 0) + + sess.expunge_all() + self.assertEquals([Address(username='ed')], sess.query(Address).all()) + + @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') def test_bidirectional_passive(self): self._test_bidirectional(True) @@ -230,6 +269,7 @@ class NaturalPKTest(_base.MappedTest): def go(): sess.flush() if passive_updates: + sess.expire(u1, ['addresses']) self.assert_sql_count(testing.db, go, 1) else: self.assert_sql_count(testing.db, go, 3) @@ -240,11 +280,11 @@ class NaturalPKTest(_base.MappedTest): u1 = sess.query(User).get('ed') assert len(u1.addresses) == 2 # load addresses u1.username = 'fred' - print "--------------------------------" def go(): sess.flush() # check that the passive_updates is on on the other side if passive_updates: + sess.expire(u1, ['addresses']) self.assert_sql_count(testing.db, go, 1) else: self.assert_sql_count(testing.db, go, 3) @@ -252,11 +292,11 @@ class NaturalPKTest(_base.MappedTest): self.assertEquals([Address(username='fred'), Address(username='fred')], sess.query(Address).all()) - @testing.fails_on('sqlite', 'FIXME: unknown') - @testing.fails_on('mysql', 'FIXME: unknown') + @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') def test_manytomany_passive(self): self._test_manytomany(True) + @testing.fails_on('mysql', 'the executemany() of the association table fails to report the correct row count') def test_manytomany_nonpassive(self): self._test_manytomany(False) @@ -355,13 +395,16 @@ class NonPKCascadeTest(_base.MappedTest): Table('users', metadata, Column('id', Integer, primary_key=True), Column('username', String(50), unique=True), - Column('fullname', String(100))) + Column('fullname', String(100)), + test_needs_fk=True) Table('addresses', metadata, Column('id', Integer, primary_key=True), Column('email', String(50)), Column('username', String(50), - ForeignKey('users.username', onupdate="cascade"))) + ForeignKey('users.username', onupdate="cascade")), + test_needs_fk=True + ) def setup_classes(self): class User(_base.ComparableEntity): @@ -369,8 +412,7 @@ class NonPKCascadeTest(_base.MappedTest): class Address(_base.ComparableEntity): pass - @testing.fails_on('sqlite', 'FIXME: unknown') - @testing.fails_on('mysql', 'FIXME: unknown') + @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') def test_onetomany_passive(self): self._test_onetomany(True) diff --git a/test/orm/onetoone.py b/test/orm/onetoone.py index 1ed3dcc619..be0375e48b 100644 --- a/test/orm/onetoone.py +++ b/test/orm/onetoone.py @@ -1,7 +1,7 @@ import testenv; testenv.configure_for_tests() from testlib import sa, testing from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation +from testlib.sa.orm import mapper, relation, create_session from orm import _base @@ -19,50 +19,56 @@ class O2OTest(_base.MappedTest): Column('description', String(100)), Column('jack_id', Integer, ForeignKey("jack.id"))) + @testing.resolve_artifact_names def setup_mappers(self): class Jack(_base.BasicEntity): pass class Port(_base.BasicEntity): pass - @testing.resolve_artifact_names - def test_1(self): - ctx = sa.orm.scoped_session(sa.orm.create_session) - mapper(Port, port, extension=ctx.extension) + @testing.resolve_artifact_names + def test_basic(self): + mapper(Port, port) mapper(Jack, jack, order_by=[jack.c.number], properties=dict( port=relation(Port, backref='jack', - uselist=False, lazy=True)), - extension=ctx.extension) + uselist=False, + )), + ) + + session = create_session() j = Jack(number='101') + session.add(j) p = Port(name='fa0/1') + session.add(p) + j.port=p - ctx.flush() + session.flush() jid = j.id pid = p.id - j=ctx.query(Jack).get(jid) - p=ctx.query(Port).get(pid) + j=session.query(Jack).get(jid) + p=session.query(Port).get(pid) assert p.jack is not None assert p.jack is j assert j.port is not None p.jack = None assert j.port is None - ctx.expunge_all() + session.expunge_all() - j = ctx.query(Jack).get(jid) - p = ctx.query(Port).get(pid) + j = session.query(Jack).get(jid) + p = session.query(Port).get(pid) j.port=None self.assert_(p.jack is None) - ctx.flush() + session.flush() - ctx.delete(j) - ctx.flush() + session.delete(j) + session.flush() if __name__ == "__main__": testenv.main() diff --git a/test/orm/query.py b/test/orm/query.py index 6531b234c6..07705c9256 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -366,7 +366,7 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): ) u7 = User(id=7) - attributes.instance_state(u7).commit_all() + attributes.instance_state(u7).commit_all(attributes.instance_dict(u7)) self._test(Address.user == u7, ":param_1 = addresses.user_id") diff --git a/test/orm/relationships.py b/test/orm/relationships.py index 88f132eae2..a0a8900b2c 100644 --- a/test/orm/relationships.py +++ b/test/orm/relationships.py @@ -835,7 +835,7 @@ class JoinConditionErrorTest(testing.TestBase): mapper(C2, t3) self.assertRaises(sa.exc.NoReferencedColumnError, compile_mappers) - + def test_join_error_raised(self): m = MetaData() t1 = Table('t1', m, @@ -1640,6 +1640,53 @@ class InvalidRelationEscalationTest(_base.MappedTest): "Could not locate any equated, locally mapped column pairs " "for primaryjoin condition", sa.orm.compile_mappers) + @testing.resolve_artifact_names + def test_ambiguous_fks(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, + primaryjoin=foos.c.id==bars.c.fid, + foreign_keys=[foos.c.id, bars.c.fid])}) + mapper(Bar, bars) + + self.assertRaisesMessage( + sa.exc.ArgumentError, + "Do the columns in 'foreign_keys' represent only the " + "'foreign' columns in this join condition ?", + sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_ambiguous_remoteside_o2m(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, + primaryjoin=foos.c.id==bars.c.fid, + foreign_keys=[bars.c.fid], + remote_side=[foos.c.id, bars.c.fid], + viewonly=True + )}) + mapper(Bar, bars) + + self.assertRaisesMessage( + sa.exc.ArgumentError, + "could not determine any local/remote column pairs", + sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_ambiguous_remoteside_m2o(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, + primaryjoin=foos.c.id==bars.c.fid, + foreign_keys=[foos.c.id], + remote_side=[foos.c.id, bars.c.fid], + viewonly=True + )}) + mapper(Bar, bars) + + self.assertRaisesMessage( + sa.exc.ArgumentError, + "could not determine any local/remote column pairs", + sa.orm.compile_mappers) + + @testing.resolve_artifact_names def test_no_equated_self_ref(self): mapper(Foo, foos, properties={ diff --git a/test/orm/session.py b/test/orm/session.py index 5a2229b16c..41c3fe7552 100644 --- a/test/orm/session.py +++ b/test/orm/session.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import create_session, sessionmaker, attributes from testlib import engines, sa, testing, config from testlib.compat import gc_collect from testlib.sa import Table, Column, Integer, String, Sequence -from testlib.sa.orm import mapper, relation, backref +from testlib.sa.orm import mapper, relation, backref, eagerload from testlib.testing import eq_ from engine import _base as engine_base from orm import _base, _fixtures @@ -776,7 +776,66 @@ class SessionTest(_fixtures.FixtureTest): user = s.query(User).one() assert user.name == 'fred' assert s.identity_map + + @testing.resolve_artifact_names + def test_weakref_with_cycles_o2m(self): + s = sessionmaker()() + mapper(User, users, properties={ + "addresses":relation(Address, backref="user") + }) + mapper(Address, addresses) + s.add(User(name="ed", addresses=[Address(email_address="ed1")])) + s.commit() + + user = s.query(User).options(eagerload(User.addresses)).one() + user.addresses[0].user # lazyload + eq_(user, User(name="ed", addresses=[Address(email_address="ed1")])) + + del user + gc_collect() + assert len(s.identity_map) == 0 + user = s.query(User).options(eagerload(User.addresses)).one() + user.addresses[0].email_address='ed2' + user.addresses[0].user # lazyload + del user + gc_collect() + assert len(s.identity_map) == 2 + + s.commit() + user = s.query(User).options(eagerload(User.addresses)).one() + eq_(user, User(name="ed", addresses=[Address(email_address="ed2")])) + + @testing.resolve_artifact_names + def test_weakref_with_cycles_o2o(self): + s = sessionmaker()() + mapper(User, users, properties={ + "address":relation(Address, backref="user", uselist=False) + }) + mapper(Address, addresses) + s.add(User(name="ed", address=Address(email_address="ed1"))) + s.commit() + + user = s.query(User).options(eagerload(User.address)).one() + user.address.user + eq_(user, User(name="ed", address=Address(email_address="ed1"))) + + del user + gc_collect() + assert len(s.identity_map) == 0 + + user = s.query(User).options(eagerload(User.address)).one() + user.address.email_address='ed2' + user.address.user # lazyload + + del user + gc_collect() + assert len(s.identity_map) == 2 + + s.commit() + user = s.query(User).options(eagerload(User.address)).one() + eq_(user, User(name="ed", address=Address(email_address="ed2"))) + @testing.resolve_artifact_names def test_strong_ref(self): s = create_session(weak_identity_map=False) @@ -792,9 +851,9 @@ class SessionTest(_fixtures.FixtureTest): assert len(s.identity_map) == 1 user = s.query(User).one() - assert not s.identity_map.modified + assert not s.identity_map._modified user.name = 'u2' - assert s.identity_map.modified + assert s.identity_map._modified s.flush() eq_(users.select().execute().fetchall(), [(user.id, 'u2')]) diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index dd1b9b766c..f1b9123135 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -14,6 +14,7 @@ from orm import _base, _fixtures from engine import _base as engine_base import pickleable from testlib.assertsql import AllOf, CompiledSQL +import gc class UnitOfWorkTest(object): pass @@ -366,6 +367,26 @@ class MutableTypesTest(_base.MappedTest): "WHERE mutable_t.id = :mutable_t_id", {'mutable_t_id': f1.id, 'val': u'hi', 'data':f1.data})]) + @testing.resolve_artifact_names + def test_resurrect(self): + f1 = Foo() + f1.data = pickleable.Bar(4,5) + f1.val = u'hi' + + session = create_session(autocommit=False) + session.add(f1) + session.commit() + + f1.data.y = 19 + del f1 + + gc.collect() + assert len(session.identity_map) == 1 + + session.commit() + + assert session.query(Foo).one().data == pickleable.Bar(4, 19) + @testing.resolve_artifact_names def test_unicode(self): """Equivalent Unicode values are not flagged as changed.""" diff --git a/test/profiling/zoomark_orm.py b/test/profiling/zoomark_orm.py index 7a189f8731..5d7192261d 100644 --- a/test/profiling/zoomark_orm.py +++ b/test/profiling/zoomark_orm.py @@ -290,11 +290,11 @@ class ZooMarkTest(TestBase): def test_profile_1_create_tables(self): self.test_baseline_1_create_tables() - @profiling.function_call_count(12925, {'2.4':12478}) + @profiling.function_call_count(12178, {'2.4':12178}) def test_profile_1a_populate(self): self.test_baseline_1a_populate() - @profiling.function_call_count(1185, {'2.4':1184}) + @profiling.function_call_count(903, {'2.4':903}) def test_profile_2_insert(self): self.test_baseline_2_insert() @@ -310,7 +310,7 @@ class ZooMarkTest(TestBase): def test_profile_5_aggregates(self): self.test_baseline_5_aggregates() - @profiling.function_call_count(3545) + @profiling.function_call_count(3343) def test_profile_6_editing(self): self.test_baseline_6_editing()