from sqlalchemy.orm import attributes, object_session, util as mapperutil, strategies
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.mapper import has_identity, object_mapper
-
+from sqlalchemy.orm.util import _state_has_identity
class DynaLoader(strategies.AbstractRelationLoader):
def init_class_attribute(self):
return history.added_items + history.unchanged_items
def fire_append_event(self, state, value, initiator):
- state.modified = True
+ collection_history = self._modified_event(state)
+ collection_history.added_items.append(value)
if self.trackparent and value is not None:
self.sethasparent(value._state, True)
ext.append(instance, value, initiator or self)
def fire_remove_event(self, state, value, initiator):
- state.modified = True
+ collection_history = self._modified_event(state)
+ collection_history.deleted_items.append(value)
if self.trackparent and value is not None:
self.sethasparent(value._state, False)
instance = state.obj()
for ext in self.extensions:
ext.remove(instance, value, initiator or self)
+
+ def _modified_event(self, state):
+ state.modified = True
+ if self.key not in state.committed_state:
+ state.committed_state[self.key] = CollectionHistory(self, state)
+
+ # this is a hack to allow the _base.ComparableEntity fixture
+ # to work
+ state.dict[self.key] = True
+
+ return state.committed_state[self.key]
def set(self, state, value, initiator):
if initiator is self:
return
-
- old_collection = self.get(state).assign(value)
-
- # TODO: emit events ???
- state.modified = True
+
+ collection_history = self._modified_event(state)
+ if _state_has_identity(state):
+ old_collection = list(self.get(state))
+ else:
+ old_collection = []
+ collection_history.replace(old_collection, value)
def delete(self, *args, **kwargs):
raise NotImplementedError()
return (c.added_items, c.unchanged_items, c.deleted_items)
def _get_collection_history(self, state, passive=False):
- try:
- c = state.dict[self.key]
- except KeyError:
- state.dict[self.key] = c = CollectionHistory(self, state)
-
+ if self.key in state.committed_state:
+ c = state.committed_state[self.key]
+ else:
+ c = CollectionHistory(self, state)
+
if not passive:
return CollectionHistory(self, state, apply_to=c)
else:
def append(self, state, value, initiator, passive=False):
if initiator is not self:
- self._get_collection_history(state, passive=True).added_items.append(value)
self.fire_append_event(state, value, initiator)
def remove(self, state, value, initiator, passive=False):
if initiator is not self:
- self._get_collection_history(state, passive=True).deleted_items.append(value)
self.fire_remove_event(state, value, initiator)
-
+
class AppenderQuery(Query):
def __init__(self, attr, state):
super(AppenderQuery, self).__init__(attr.target_mapper, None)
q = q.order_by(self.attr.order_by)
return q
- def assign(self, collection):
- instance = self.instance
- if has_identity(instance):
- oldlist = list(self)
- else:
- oldlist = []
- self.attr._get_collection_history(self.instance._state, passive=True).replace(oldlist, collection)
- return oldlist
-
def append(self, item):
self.attr.append(self.instance._state, item, None)
sess.save(u1)
sess.save(u2)
sess.flush()
-
+
+ from sqlalchemy.orm import attributes
+ self.assertEquals(attributes.get_history(u1._state, 'addresses'), ([], [Address(email_address='lala@hoho.com')], []))
+
sess.clear()
# test the test fixture a little bit
User(name='jack', addresses=[Address(email_address='lala@hoho.com')]),
User(name='ed', addresses=[Address(email_address='foo@bar.com')])
] == sess.query(User).all()
+
+ def test_hasattr(self):
+ mapper(User, users, properties={
+ 'addresses':dynamic_loader(mapper(Address, addresses))
+ })
+ u1 = User(name='jack')
+
+ assert 'addresses' not in u1.__dict__.keys()
+ u1.addresses = [Address(email_address='test')]
+ assert 'addresses' in dir(u1)
+
+ def test_rollback(self):
+ mapper(User, users, properties={
+ 'addresses':dynamic_loader(mapper(Address, addresses))
+ })
+ sess = create_session(transactional=True, autoflush=True)
+ u1 = User(name='jack')
+ u1.addresses.append(Address(email_address='lala@hoho.com'))
+ sess.save(u1)
+ sess.flush()
+ sess.commit()
+ u1.addresses.append(Address(email_address='foo@bar.com'))
+ self.assertEquals(u1.addresses.all(), [Address(email_address='lala@hoho.com'), Address(email_address='foo@bar.com')])
+ sess.rollback()
+ self.assertEquals(u1.addresses.all(), [Address(email_address='lala@hoho.com')])
@testing.fails_on('maxdb')
def test_delete_nocascade(self):