From: Mike Bayer Date: Sat, 19 Jul 2008 21:40:34 +0000 (+0000) Subject: - A critical fix to dynamic relations allows the X-Git-Tag: rel_0_4_7~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f70ae049c6fdd14dec1240dc2fa0a0242df827cb;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - A critical fix to dynamic relations allows the "modified" history to be properly cleared after a flush() (backported from 0.5). --- diff --git a/CHANGES b/CHANGES index 102e71c94f..4973436de0 100644 --- a/CHANGES +++ b/CHANGES @@ -10,6 +10,10 @@ CHANGES that multiple contains() calls will not conflict with each other [ticket:1058] + - A critical fix to dynamic relations allows the + "modified" history to be properly cleared after + a flush() (backported from 0.5). + - fixed bug preventing merge() from functioning in conjunction with a comparable_property() diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index 133ad99c89..19bdeab1dd 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -5,7 +5,7 @@ from sqlalchemy import exceptions, util, logging from sqlalchemy.orm import attributes, object_session, util as mapperutil, strategies from sqlalchemy.orm.query import Query from sqlalchemy.orm.mapper import has_identity, object_mapper - +from sqlalchemy.orm.util import _state_has_identity class DynaLoader(strategies.AbstractRelationLoader): def init_class_attribute(self): @@ -38,7 +38,8 @@ class DynamicAttributeImpl(attributes.AttributeImpl): return history.added_items + history.unchanged_items def fire_append_event(self, state, value, initiator): - state.modified = True + collection_history = self._modified_event(state) + collection_history.added_items.append(value) if self.trackparent and value is not None: self.sethasparent(value._state, True) @@ -47,7 +48,8 @@ class DynamicAttributeImpl(attributes.AttributeImpl): ext.append(instance, value, initiator or self) def fire_remove_event(self, state, value, initiator): - state.modified = True + collection_history = self._modified_event(state) + collection_history.deleted_items.append(value) if self.trackparent and value is not None: self.sethasparent(value._state, False) @@ -55,15 +57,28 @@ class DynamicAttributeImpl(attributes.AttributeImpl): instance = state.obj() for ext in self.extensions: ext.remove(instance, value, initiator or self) + + def _modified_event(self, state): + state.modified = True + if self.key not in state.committed_state: + state.committed_state[self.key] = CollectionHistory(self, state) + + # this is a hack to allow the _base.ComparableEntity fixture + # to work + state.dict[self.key] = True + + return state.committed_state[self.key] def set(self, state, value, initiator): if initiator is self: return - - old_collection = self.get(state).assign(value) - - # TODO: emit events ??? - state.modified = True + + collection_history = self._modified_event(state) + if _state_has_identity(state): + old_collection = list(self.get(state)) + else: + old_collection = [] + collection_history.replace(old_collection, value) def delete(self, *args, **kwargs): raise NotImplementedError() @@ -73,11 +88,11 @@ class DynamicAttributeImpl(attributes.AttributeImpl): return (c.added_items, c.unchanged_items, c.deleted_items) def _get_collection_history(self, state, passive=False): - try: - c = state.dict[self.key] - except KeyError: - state.dict[self.key] = c = CollectionHistory(self, state) - + if self.key in state.committed_state: + c = state.committed_state[self.key] + else: + c = CollectionHistory(self, state) + if not passive: return CollectionHistory(self, state, apply_to=c) else: @@ -85,15 +100,13 @@ class DynamicAttributeImpl(attributes.AttributeImpl): def append(self, state, value, initiator, passive=False): if initiator is not self: - self._get_collection_history(state, passive=True).added_items.append(value) self.fire_append_event(state, value, initiator) def remove(self, state, value, initiator, passive=False): if initiator is not self: - self._get_collection_history(state, passive=True).deleted_items.append(value) self.fire_remove_event(state, value, initiator) - + class AppenderQuery(Query): def __init__(self, attr, state): super(AppenderQuery, self).__init__(attr.target_mapper, None) @@ -152,15 +165,6 @@ class AppenderQuery(Query): q = q.order_by(self.attr.order_by) return q - def assign(self, collection): - instance = self.instance - if has_identity(instance): - oldlist = list(self) - else: - oldlist = [] - self.attr._get_collection_history(self.instance._state, passive=True).replace(oldlist, collection) - return oldlist - def append(self, item): self.attr.append(self.instance._state, item, None) diff --git a/test/orm/dynamic.py b/test/orm/dynamic.py index c38b278238..3a851495e4 100644 --- a/test/orm/dynamic.py +++ b/test/orm/dynamic.py @@ -118,7 +118,10 @@ class FlushTest(FixtureTest): sess.save(u1) sess.save(u2) sess.flush() - + + from sqlalchemy.orm import attributes + self.assertEquals(attributes.get_history(u1._state, 'addresses'), ([], [Address(email_address='lala@hoho.com')], [])) + sess.clear() # test the test fixture a little bit @@ -129,6 +132,31 @@ class FlushTest(FixtureTest): User(name='jack', addresses=[Address(email_address='lala@hoho.com')]), User(name='ed', addresses=[Address(email_address='foo@bar.com')]) ] == sess.query(User).all() + + def test_hasattr(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses)) + }) + u1 = User(name='jack') + + assert 'addresses' not in u1.__dict__.keys() + u1.addresses = [Address(email_address='test')] + assert 'addresses' in dir(u1) + + def test_rollback(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses)) + }) + sess = create_session(transactional=True, autoflush=True) + u1 = User(name='jack') + u1.addresses.append(Address(email_address='lala@hoho.com')) + sess.save(u1) + sess.flush() + sess.commit() + u1.addresses.append(Address(email_address='foo@bar.com')) + self.assertEquals(u1.addresses.all(), [Address(email_address='lala@hoho.com'), Address(email_address='foo@bar.com')]) + sess.rollback() + self.assertEquals(u1.addresses.all(), [Address(email_address='lala@hoho.com')]) @testing.fails_on('maxdb') def test_delete_nocascade(self):