From 6c4ad36cc9004db1d9dffe28a95e3556d14e2c82 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 14 Dec 2007 23:11:13 +0000 Subject: [PATCH] - simplified _mapper_registry further. its now just a weakkeydict of mapper->True, stores all mappers including non primaries, and is strictly used for the list of "to compile/dispose". - all global references are now weak referencing. if you del a mapped class and any dependent classes, its mapper and all dependencies fall out of scope. - attributes.py still had issues which were barely covered by tests. added way more tests (coverage.py still says 71%, doh) fixed things, took out unnecessary commit to states. attribute history is also asserted for ordering. --- lib/sqlalchemy/orm/__init__.py | 10 +-- lib/sqlalchemy/orm/attributes.py | 103 +++++++++---------------- lib/sqlalchemy/orm/dependency.py | 2 +- lib/sqlalchemy/orm/mapper.py | 39 +++++----- lib/sqlalchemy/orm/properties.py | 20 ++--- lib/sqlalchemy/util.py | 4 +- test/orm/attributes.py | 127 ++++++++++++++++++++++++++++--- test/orm/mapper.py | 13 ++++ test/orm/memusage.py | 68 ++++++++++++++++- test/orm/naturalpks.py | 12 ++- 10 files changed, 277 insertions(+), 121 deletions(-) diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index ac784ec08d..7e8d8b8bf8 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -592,9 +592,8 @@ def compile_mappers(): This is equivalent to calling ``compile()`` on any individual mapper. """ - if not _mapper_registry: - return - _mapper_registry.values()[0][0].compile() + for m in list(_mapper_registry): + m.compile() def clear_mappers(): """Remove all mappers that have been created thus far. @@ -604,11 +603,8 @@ def clear_mappers(): """ mapperlib._COMPILE_MUTEX.acquire() try: - for mapper in chain(*_mapper_registry.values()): + for mapper in _mapper_registry: mapper.dispose() - _mapper_registry.clear() - from sqlalchemy.orm import dependency - dependency.MapperStub.dispose(dependency.MapperStub) finally: mapperlib._COMPILE_MUTEX.release() diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index af589f3403..01be8813f0 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -429,11 +429,9 @@ class CollectionAttributeImpl(AttributeImpl): return _create_history(self, state, current) def fire_append_event(self, state, value, initiator): - if self.key not in state.committed_state: - if self.key in state.dict: - state.committed_state[self.key] = self.copy(state.dict[self.key]) - else: - state.committed_state[self.key] = NO_VALUE + if self.key not in state.committed_state and self.key in state.dict: + state.committed_state[self.key] = self.copy(state.dict[self.key]) + state.modified = True if self.trackparent and value is not None: @@ -443,18 +441,13 @@ class CollectionAttributeImpl(AttributeImpl): ext.append(instance, value, initiator or self) def fire_pre_remove_event(self, state, initiator): - if self.key not in state.committed_state: - if self.key in state.dict: - state.committed_state[self.key] = self.copy(state.dict[self.key]) - else: - state.committed_state[self.key] = NO_VALUE + if self.key not in state.committed_state and self.key in state.dict: + state.committed_state[self.key] = self.copy(state.dict[self.key]) def fire_remove_event(self, state, value, initiator): - if self.key not in state.committed_state: - if self.key in state.dict: - state.committed_state[self.key] = self.copy(state.dict[self.key]) - else: - state.committed_state[self.key] = NO_VALUE + if self.key not in state.committed_state and self.key in state.dict: + state.committed_state[self.key] = self.copy(state.dict[self.key]) + state.modified = True if self.trackparent and value is not None: @@ -492,12 +485,6 @@ class CollectionAttributeImpl(AttributeImpl): if initiator is self: return - if self.key not in state.committed_state: - if self.key in state.dict: - state.committed_state[self.key] = self.copy(state.dict[self.key]) - else: - state.committed_state[self.key] = NO_VALUE - collection = self.get_collection(state, passive=passive) if collection is PASSIVE_NORESULT: state.get_pending(self.key).append(value) @@ -509,12 +496,6 @@ class CollectionAttributeImpl(AttributeImpl): if initiator is self: return - if self.key not in state.committed_state: - if self.key in state.dict: - state.committed_state[self.key] = self.copy(state.dict[self.key]) - else: - state.committed_state[self.key] = NO_VALUE - collection = self.get_collection(state, passive=passive) if collection is PASSIVE_NORESULT: state.get_pending(self.key).remove(value) @@ -533,12 +514,6 @@ class CollectionAttributeImpl(AttributeImpl): if initiator is self: return - if self.key not in state.committed_state: - if self.key in state.dict: - state.committed_state[self.key] = self.copy(state.dict[self.key]) - else: - state.committed_state[self.key] = NO_VALUE - # we need a CollectionAdapter to adapt the incoming value to an # assignable iterable. pulling a new collection first so that # an adaptation exception does not trigger a lazy load of the @@ -547,6 +522,8 @@ class CollectionAttributeImpl(AttributeImpl): new_values = list(new_collection.adapt_like_to_iterable(value)) old = self.get(state) + state.committed_state[self.key] = self.copy(old) + old_collection = self.get_collection(state, old) idset = util.IdentitySet @@ -576,16 +553,30 @@ class CollectionAttributeImpl(AttributeImpl): """ collection, user_data = self._build_collection(state) - self._load_collection(state, value or [], emit_events=False, - collection=collection) - value = user_data - if state.committed_state is not None: - state.commit_attr(self, value) - # remove per-instance callable, if any + if value: + for item in value: + collection.append_without_event(item) + state.callables.pop(self.key, None) - state.dict[self.key] = value - return value + state.dict[self.key] = user_data + + if self.key in state.pending: + # pending items. commit loaded data, add/remove new data + state.committed_state[self.key] = list(value or []) + added = state.pending[self.key].added_items + removed = state.pending[self.key].deleted_items + for item in added: + collection.append_without_event(item) + for item in removed: + collection.remove_without_event(item) + del state.pending[self.key] + elif self.key in state.committed_state: + # no pending items. remove committed state if any. + # (this can occur with an expired attribute) + del state.committed_state[self.key] + + return user_data def _build_collection(self, state): """build a new, blank collection and return it wrapped in a CollectionAdapter.""" @@ -594,34 +585,12 @@ class CollectionAttributeImpl(AttributeImpl): collection = collections.CollectionAdapter(self, state, user_data) return collection, user_data - def _load_collection(self, state, values, emit_events=True, collection=None): - """given an empty CollectionAdapter, load the collection with current values. - - Loads the collection from lazy callables in all cases. - """ + def get_collection(self, state, user_data=None, passive=False): + """retrieve the CollectionAdapter associated with the given state. - collection = collection or self.get_collection(state) - if values is None: - return - - appender = emit_events and collection.append_with_event or collection.append_without_event + Creates a new CollectionAdapter if one does not exist. - if self.key in state.pending: - # move 'pending' items into the newly loaded collection - added = state.pending[self.key].added_items - removed = state.pending[self.key].deleted_items - for item in values: - if item not in removed: - appender(item) - for item in added: - appender(item) - del state.pending[self.key] - else: - for item in values: - appender(item) - - def get_collection(self, state, user_data=None, passive=False): - """retrieve the CollectionAdapter associated with the given state.""" + """ if user_data is None: user_data = self.get(state, passive=passive) diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index 0a097fd24a..8340ccdcc6 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -481,7 +481,7 @@ class MapperStub(object): """ __metaclass__ = util.ArgSingleton - + def __init__(self, parent, mapper, key): self.mapper = mapper self.base_mapper = self diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 6ea45a598a..d8c6a4023f 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -183,7 +183,14 @@ class Mapper(object): return False def get_property(self, key, resolve_synonyms=False, raiseerr=True): - """return MapperProperty with the given key.""" + """return a MapperProperty associated with the given key.""" + + self.compile() + return self._get_property(key, resolve_synonyms=resolve_synonyms, raiseerr=raiseerr) + + def _get_property(self, key, resolve_synonyms=False, raiseerr=True): + """private in-compilation version of get_property().""" + prop = self.__props.get(key, None) if resolve_synonyms: while isinstance(prop, SynonymProperty): @@ -193,6 +200,7 @@ class Mapper(object): return prop def iterate_properties(self): + self.compile() return self.__props.itervalues() iterate_properties = property(iterate_properties, doc="returns an iterator of all MapperProperty objects.") @@ -200,11 +208,15 @@ class Mapper(object): raise NotImplementedError("Public collection of MapperProperty objects is provided by the get_property() and iterate_properties accessors.") properties = property(properties) + compiled = property(lambda self:self.__props_init, doc="return True if this mapper is compiled") + def dispose(self): # disaable any attribute-based compilation self.__props_init = True - if hasattr(self.class_, 'c'): + try: del self.class_.c + except AttributeError: + pass if not self.non_primary and self.entity_name in self._class_state.mappers: del self._class_state.mappers[self.entity_name] if not self._class_state.mappers: @@ -221,24 +233,16 @@ class Mapper(object): # double-check inside mutex if self.__props_init: return self + # initialize properties on all mappers - for mapper in chain(*_mapper_registry.values()): + for mapper in list(_mapper_registry): if not mapper.__props_init: mapper.__initialize_properties() - # if we're not primary, compile us - if self.non_primary: - self.__initialize_properties() - return self finally: _COMPILE_MUTEX.release() - def _check_compile(self): - if self.non_primary and not self.__props_init: - self.__initialize_properties() - return self - def __initialize_properties(self): """Call the ``init()`` method on all ``MapperProperties`` attached to this mapper. @@ -727,6 +731,7 @@ class Mapper(object): if self.non_primary: self._class_state = self.class_._class_state + _mapper_registry[self] = True return if not self.non_primary and '_class_state' in self.class_.__dict__ and (self.entity_name in self.class_._class_state.mappers): @@ -743,15 +748,9 @@ class Mapper(object): attributes.register_class(self.class_, extra_init=extra_init, on_exception=on_exception) self._class_state = self.class_._class_state - if self._class_state not in _mapper_registry: - _mapper_registry[self._class_state] = [] + _mapper_registry[self] = True - _COMPILE_MUTEX.acquire() - try: - _mapper_registry[self._class_state].append(self) - self.class_._class_state.mappers[self.entity_name] = self - finally: - _COMPILE_MUTEX.release() + self.class_._class_state.mappers[self.entity_name] = self for ext in util.to_list(self.extension, []): ext.instrument_class(self, self.class_) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index b6d6cef638..027cefd692 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -150,7 +150,7 @@ class SynonymProperty(MapperProperty): def do_init(self): class_ = self.parent.class_ - aliased_property = self.parent.get_property(self.key, resolve_synonyms=True) + aliased_property = self.parent._get_property(self.key, resolve_synonyms=True) self.logger.info("register managed attribute %s on class %s" % (self.key, class_.__name__)) if self.instrument is None: class SynonymProp(object): @@ -396,23 +396,23 @@ class PropertyLoader(StrategizedProperty): def _determine_targets(self): if isinstance(self.argument, type): - self.mapper = mapper.class_mapper(self.argument, entity_name=self.entity_name, compile=False)._check_compile() + self.mapper = mapper.class_mapper(self.argument, entity_name=self.entity_name, compile=False) elif isinstance(self.argument, mapper.Mapper): - self.mapper = self.argument._check_compile() + self.mapper = self.argument else: raise exceptions.ArgumentError("relation '%s' expects a class or a mapper argument (received: %s)" % (self.key, type(self.argument))) # ensure the "select_mapper", if different from the regular target mapper, is compiled. - self.mapper.get_select_mapper()._check_compile() + self.mapper.get_select_mapper() if not self.parent.concrete: for inheriting in self.parent.iterate_to_root(): - if inheriting is not self.parent and inheriting.get_property(self.key, raiseerr=False): + if inheriting is not self.parent and inheriting._get_property(self.key, raiseerr=False): warnings.warn(RuntimeWarning("Warning: relation '%s' on mapper '%s' supercedes the same relation on inherited mapper '%s'; this can cause dependency issues during flush" % (self.key, self.parent, inheriting))) if self.association is not None: if isinstance(self.association, type): - self.association = mapper.class_mapper(self.association, entity_name=self.entity_name, compile=False)._check_compile() + self.association = mapper.class_mapper(self.association, entity_name=self.entity_name, compile=False) self.target = self.mapper.mapped_table self.select_mapper = self.mapper.get_select_mapper() @@ -650,7 +650,7 @@ class PropertyLoader(StrategizedProperty): if self.backref is not None: self.backref.compile(self) - elif not mapper.class_mapper(self.parent.class_).get_property(self.key, raiseerr=False): + elif not mapper.class_mapper(self.parent.class_, compile=False)._get_property(self.key, raiseerr=False): raise exceptions.ArgumentError("Attempting to assign a new relation '%s' to a non-primary mapper on class '%s'. New relations can only be added to the primary mapper, i.e. the very first mapper created for class '%s' " % (self.key, self.parent.class_.__name__, self.parent.class_.__name__)) super(PropertyLoader, self).do_init() @@ -727,7 +727,7 @@ class BackRef(object): self.prop = prop mapper = prop.mapper.primary_mapper() - if mapper.get_property(self.key, raiseerr=False) is None: + if mapper._get_property(self.key, raiseerr=False) is None: pj = self.kwargs.pop('primaryjoin', None) sj = self.kwargs.pop('secondaryjoin', None) @@ -742,8 +742,8 @@ class BackRef(object): mapper._compile_property(self.key, relation); - prop.reverse_property = mapper.get_property(self.key) - mapper.get_property(self.key).reverse_property = prop + prop.reverse_property = mapper._get_property(self.key) + mapper._get_property(self.key).reverse_property = prop else: raise exceptions.ArgumentError("Error creating backref '%s' on relation '%s': property of that name exists on mapper '%s'" % (self.key, prop, mapper)) diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 3e26217c9b..463d4b8afd 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -4,7 +4,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import itertools, sys, warnings, sets +import itertools, sys, warnings, sets, weakref import __builtin__ from sqlalchemy import exceptions, logging @@ -113,7 +113,7 @@ def flatten_iterator(x): yield elem class ArgSingleton(type): - instances = {} + instances = weakref.WeakValueDictionary() def dispose(cls): for key in list(ArgSingleton.instances): diff --git a/test/orm/attributes.py b/test/orm/attributes.py index 3b3ed61025..74d94a4b4d 100644 --- a/test/orm/attributes.py +++ b/test/orm/attributes.py @@ -232,10 +232,7 @@ class AttributesTest(PersistTest): x = Foo() x._state.commit_all() x.col2.append(bar4) - (added, unchanged, deleted) = attributes.get_history(x._state, 'col2') - - self.assertEquals(set(unchanged), set([bar1, bar2, bar3])) - self.assertEquals(added, [bar4]) + self.assertEquals(attributes.get_history(x._state, 'col2'), ([bar4], [bar1, bar2, bar3], [])) def test_parenttrack(self): class Foo(object):pass @@ -729,19 +726,19 @@ class HistoryTest(PersistTest): self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [hi], [])) f._state.commit(['someattr']) - self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set(), set([there, hi]), set())) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi, there], [])) f.someattr.remove(there) - self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set(), set([hi]), set([there]))) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], [there])) f.someattr.append(old) f.someattr.append(new) - self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set([new, old]), set([hi]), set([there]))) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old, new], [hi], [there])) f._state.commit(['someattr']) - self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set(), set([new, old, hi]), set())) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi, old, new], [])) f.someattr.pop(0) - self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set(), set([new, old]), set([hi]))) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [old, new], [hi])) # case 2. object with direct settings (similar to a load operation) f = Foo() @@ -755,7 +752,7 @@ class HistoryTest(PersistTest): self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old], [new], [])) f._state.commit(['someattr']) - self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set([]), set([old, new]), set([]))) + self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new, old], [])) f = Foo() collection = attributes.init_collection(f, 'someattr') @@ -791,10 +788,118 @@ class HistoryTest(PersistTest): b2 = Bar() f1.bars.append(b2) - self.assertEquals(tuple([set(x) for x in attributes.get_history(f1._state, 'bars')]), (set([b1, b2]), set([]), set([]))) + self.assertEquals(attributes.get_history(f1._state, 'bars'), ([b1, b2], [], [])) self.assertEquals(attributes.get_history(b1._state, 'foo'), ([f1], [], [])) self.assertEquals(attributes.get_history(b2._state, 'foo'), ([f1], [], [])) + + def test_lazy_backref_collections(self): + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + pass + + lazy_load = [] + def lazyload(instance): + def load(): + return lazy_load + return load + + attributes.register_class(Foo) + attributes.register_class(Bar) + attributes.register_attribute(Foo, 'bars', uselist=True, extension=attributes.GenericBackrefExtension('foo'), trackparent=True, callable_=lazyload, useobject=True) + attributes.register_attribute(Bar, 'foo', uselist=False, extension=attributes.GenericBackrefExtension('bars'), trackparent=True, useobject=True) + + bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), Bar(id=4)] + lazy_load = [bar1, bar2, bar3] + + f = Foo() + bar4 = Bar() + bar4.foo = f + self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [bar1, bar2, bar3], [])) + + lazy_load = None + f = Foo() + bar4 = Bar() + bar4.foo = f + self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [], [])) + + lazy_load = [bar1, bar2, bar3] + f._state.trigger = lazyload(f) + f._state.expire_attributes(['bars']) + self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar2, bar3], [])) + + def test_collections_via_lazyload(self): + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + pass + + lazy_load = [] + def lazyload(instance): + def load(): + return lazy_load + return load + + attributes.register_class(Foo) + attributes.register_class(Bar) + attributes.register_attribute(Foo, 'bars', uselist=True, callable_=lazyload, trackparent=True, useobject=True) + + bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), Bar(id=4)] + lazy_load = [bar1, bar2, bar3] + + f = Foo() + f.bars = [] + self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [], [bar1, bar2, bar3])) + + f = Foo() + f.bars.append(bar4) + self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [bar1, bar2, bar3], []) ) + + f = Foo() + f.bars.remove(bar2) + self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar3], [bar2])) + f.bars.append(bar4) + self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [bar1, bar3], [bar2])) + + f = Foo() + del f.bars[1] + self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar3], [bar2])) + + lazy_load = None + f = Foo() + f.bars.append(bar2) + self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar2], [], [])) + + def test_scalar_via_lazyload(self): + class Foo(fixtures.Base): + pass + class Bar(fixtures.Base): + pass + + lazy_load = None + def lazyload(instance): + def load(): + return lazy_load + return load + + attributes.register_class(Foo) + attributes.register_class(Bar) + attributes.register_attribute(Foo, 'bar', uselist=False, callable_=lazyload, trackparent=True, useobject=True) + bar1, bar2 = [Bar(id=1), Bar(id=2)] + lazy_load = bar1 + + f = Foo() + self.assertEquals(attributes.get_history(f._state, 'bar'), ([], [bar1], [])) + f = Foo() + f.bar = None + self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], [bar1])) + + f = Foo() + f.bar = bar2 + self.assertEquals(attributes.get_history(f._state, 'bar'), ([bar2], [], [bar1])) + f.bar = bar1 + self.assertEquals(attributes.get_history(f._state, 'bar'), ([], [bar1], [])) if __name__ == "__main__": testbase.main() diff --git a/test/orm/mapper.py b/test/orm/mapper.py index d8b552c69c..a1ca19a64f 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -163,6 +163,19 @@ class MapperTest(MapperSuperTest): mapper(Foo, addresses, inherits=User) assert getattr(Foo().__class__, 'user_name').impl is not None + def test_compileon_getprops(self): + m =mapper(User, users) + + assert not m.compiled + assert list(m.iterate_properties) + assert m.compiled + clear_mappers() + + m= mapper(User, users) + assert not m.compiled + assert m.get_property('user_name') + assert m.compiled + def test_add_property(self): assert_col = [] class User(object): diff --git a/test/orm/memusage.py b/test/orm/memusage.py index cd109f4fa1..f0fba31234 100644 --- a/test/orm/memusage.py +++ b/test/orm/memusage.py @@ -182,7 +182,73 @@ class MemUsageTest(AssertMixin): for a in alist: sess.delete(a) sess.flush() - clear_mappers() + + # dont need to clear_mappers() + del B + del A + + metadata.create_all() + try: + go() + finally: + metadata.drop_all() + + def test_with_manytomany(self): + metadata = MetaData(testbase.db) + + table1 = Table("mytable", metadata, + Column('col1', Integer, primary_key=True), + Column('col2', String(30)) + ) + + table2 = Table("mytable2", metadata, + Column('col1', Integer, primary_key=True), + Column('col2', String(30)), + ) + + table3 = Table('t1tot2', metadata, + Column('t1', Integer, ForeignKey('mytable.col1')), + Column('t2', Integer, ForeignKey('mytable2.col1')), + ) + + @profile_memory + def go(): + class A(Base): + pass + class B(Base): + pass + + mapper(A, table1, properties={ + 'bs':relation(B, secondary=table3, backref='as') + }) + mapper(B, table2) + + sess = create_session() + a1 = A(col2='a1') + a2 = A(col2='a2') + b1 = B(col2='b1') + b2 = B(col2='b2') + a1.bs.append(b1) + a2.bs.append(b2) + for x in [a1,a2]: + sess.save(x) + sess.flush() + sess.clear() + + alist = sess.query(A).all() + self.assertEquals( + [ + A(bs=[B(col2='b1')]), A(bs=[B(col2='b2')]) + ], + alist) + + for a in alist: + sess.delete(a) + sess.flush() + + # dont need to clear_mappers() + del B + del A metadata.create_all() try: diff --git a/test/orm/naturalpks.py b/test/orm/naturalpks.py index 515c30e298..069aea0b35 100644 --- a/test/orm/naturalpks.py +++ b/test/orm/naturalpks.py @@ -310,11 +310,14 @@ class NonPKCascadeTest(ORMTest): sess.flush() a1 = u1.addresses[0] + self.assertEquals(select([addresses.c.username]).execute().fetchall(), [('jack',), ('jack',)]) + assert sess.get(Address, a1.id) is u1.addresses[0] u1.username = 'ed' sess.flush() assert u1.addresses[0].username == 'ed' + self.assertEquals(select([addresses.c.username]).execute().fetchall(), [('ed',), ('ed',)]) sess.clear() self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) @@ -329,13 +332,18 @@ class NonPKCascadeTest(ORMTest): self.assert_sql_count(testbase.db, go, 1) # test passive_updates=True; update user sess.clear() assert User(username='jack', addresses=[Address(username='jack'), Address(username='jack')]) == sess.get(User, u1.id) - + sess.clear() + u1 = sess.get(User, u1.id) u1.addresses = [] u1.username = 'fred' sess.flush() sess.clear() - assert sess.get(Address, a1.id).username is None + a1 = sess.get(Address, a1.id) + self.assertEquals(a1.username, None) + + self.assertEquals(select([addresses.c.username]).execute().fetchall(), [(None,), (None,)]) + u1 = sess.get(User, u1.id) self.assertEquals(User(username='fred', fullname='jack'), u1) -- 2.47.3