)
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.util import _state_has_identity, has_identity
-from sqlalchemy.orm import attributes
+from sqlalchemy.orm import attributes, collections
class DynaLoader(strategies.AbstractRelationLoader):
def init_class_attribute(self, mapper):
if initiator is self:
return
+ self._set_iterable(state, value)
+
+ def _set_iterable(self, state, iterable, adapter=None):
+
collection_history = self._modified_event(state)
+ new_values = list(iterable)
+
if _state_has_identity(state):
old_collection = list(self.get(state))
else:
old_collection = []
- collection_history.replace(old_collection, value)
+
+ collections.bulk_replace(new_values, DynCollectionAdapter(self, state, old_collection), DynCollectionAdapter(self, state, new_values))
def delete(self, *args, **kwargs):
raise NotImplementedError()
if initiator is not self:
self.fire_remove_event(state, value, initiator)
+class DynCollectionAdapter(object):
+ """the dynamic analogue to orm.collections.CollectionAdapter"""
+
+ def __init__(self, attr, owner_state, data):
+ self.attr = attr
+ self.state = owner_state
+ self.data = data
+
+ def __iter__(self):
+ return iter(self.data)
+
+ def append_with_event(self, item, initiator=None):
+ self.attr.append(self.state, item, initiator)
+
+ def remove_with_event(self, item, initiator=None):
+ self.attr.remove(self.state, item, initiator)
+
+ def append_without_event(self, item):
+ pass
+
+ def remove_without_event(self, item):
+ pass
class AppenderMixin(object):
query_class = None
self.deleted_items = []
self.added_items = []
self.unchanged_items = []
-
- def replace(self, olditems, newitems):
- self.added_items = newitems
- self.deleted_items = olditems
from sqlalchemy.orm import dynamic_loader, backref
from testlib import testing
from testlib.sa import Table, Column, Integer, String, ForeignKey, desc, select, func
-from testlib.sa.orm import mapper, relation, create_session, Query
+from testlib.sa.orm import mapper, relation, create_session, Query, attributes
from testlib.testing import eq_
from testlib.compat import _function_named
from orm import _base, _fixtures
assert type(q).__name__ == 'MyQuery'
-class FlushTest(_fixtures.FixtureTest):
+class SessionTest(_fixtures.FixtureTest):
run_inserts = None
@testing.resolve_artifact_names
(a2.id, u1.id, 'bar')
]
+
+ @testing.resolve_artifact_names
+ def test_merge(self):
+ mapper(User, users, properties={
+ 'addresses':dynamic_loader(mapper(Address, addresses), order_by=addresses.c.email_address)
+ })
+ sess = create_session()
+ u1 = User(name='jack')
+ a1 = Address(email_address='a1')
+ a2 = Address(email_address='a2')
+ a3 = Address(email_address='a3')
+
+ u1.addresses.append(a2)
+ u1.addresses.append(a3)
+
+ sess.add_all([u1, a1])
+ sess.flush()
+
+ u1 = User(id=u1.id, name='jack')
+ u1.addresses.append(a1)
+ u1.addresses.append(a3)
+ u1 = sess.merge(u1)
+ assert attributes.get_history(u1, 'addresses') == (
+ [a1],
+ [a3],
+ [a2]
+ )
+
+ sess.flush()
+
+ eq_(
+ list(u1.addresses),
+ [a1, a3]
+ )
@testing.resolve_artifact_names
- def test_basic(self):
+ def test_flush(self):
mapper(User, users, properties={
'addresses':dynamic_loader(mapper(Address, addresses))
})
assert 'addresses' not in u1.__dict__.keys()
u1.addresses = [Address(email_address='test')]
assert 'addresses' in dir(u1)
+
+ @testing.resolve_artifact_names
+ def test_collection_set(self):
+ mapper(User, users, properties={
+ 'addresses':dynamic_loader(mapper(Address, addresses), order_by=addresses.c.email_address)
+ })
+ sess = create_session(autoflush=True, autocommit=False)
+ u1 = User(name='jack')
+ a1 = Address(email_address='a1')
+ a2 = Address(email_address='a2')
+ a3 = Address(email_address='a3')
+ a4 = Address(email_address='a4')
+
+ sess.add(u1)
+ u1.addresses = [a1, a3]
+ assert list(u1.addresses) == [a1, a3]
+ u1.addresses = [a1, a2, a4]
+ assert list(u1.addresses) == [a1, a2, a4]
+ u1.addresses = [a2, a3]
+ assert list(u1.addresses) == [a2, a3]
+ u1.addresses = []
+ assert list(u1.addresses) == []
+
+
+
@testing.resolve_artifact_names
def test_rollback(self):
test_backref = _function_named(
test_backref, "test%s%s" % ((autoflush and "_autoflush" or ""),
(saveuser and "_saveuser" or "_savead")))
- setattr(FlushTest, test_backref.__name__, test_backref)
+ setattr(SessionTest, test_backref.__name__, test_backref)
for autoflush in (False, True):
for saveuser in (False, True):