From 9e4052dc8be2451d1c48bb059da150ce41ddc86f Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 5 Dec 2007 20:43:16 +0000 Subject: [PATCH] - a major behavioral change to collection-based backrefs: they no longer trigger lazy loads ! "reverse" adds and removes are queued up and are merged with the collection when it is actually read from and loaded; but do not trigger a load beforehand. For users who have noticed this behavior, this should be much more convenient than using dynamic relations in some cases; for those who have not, you might notice your apps using a lot fewer queries than before in some situations. [ticket:871] --- CHANGES | 10 +- lib/sqlalchemy/orm/attributes.py | 109 +++++++++++--- lib/sqlalchemy/orm/dynamic.py | 4 +- test/orm/attributes.py | 242 +++++++++++++++++++++---------- test/orm/lazy_relations.py | 34 +++++ test/orm/mapper.py | 8 +- test/orm/unitofwork.py | 3 +- 7 files changed, 305 insertions(+), 105 deletions(-) diff --git a/CHANGES b/CHANGES index 935bfb6ce3..3b212fe82f 100644 --- a/CHANGES +++ b/CHANGES @@ -31,7 +31,15 @@ CHANGES - from_obj keyword argument to select() can be a scalar or a list. - orm - + - a major behavioral change to collection-based backrefs: they no + longer trigger lazy loads ! "reverse" adds and removes + are queued up and are merged with the collection when it is + actually read from and loaded; but do not trigger a load beforehand. + For users who have noticed this behavior, this should be much more + convenient than using dynamic relations in some cases; for those who + have not, you might notice your apps using a lot fewer queries than + before in some situations. [ticket:871] + - new synonym() behavior: an attribute will be placed on the mapped class, if one does not exist already, in all cases. if a property already exists on the class, the synonym will decorate the property diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 8268d0816c..bb7085402d 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -151,6 +151,7 @@ class AttributeImpl(object): value = state.dict[self.key] if value is not NO_VALUE: state.committed_state[self.key] = self.copy(value) + state.pending.pop(self.key, None) def hasparent(self, state, optimistic=False): """Return the boolean value of a `hasparent` flag attached to the given item. @@ -181,7 +182,7 @@ class AttributeImpl(object): current = self.get(state, passive=passive) if current is PASSIVE_NORESULT: return None - return AttributeHistory(self, state, current, passive=passive) + return AttributeHistory(self, state, current) def set_callable(self, state, callable_, clear=False): """Set a callable function for this attribute on the given object. @@ -249,10 +250,10 @@ class AttributeImpl(object): # Return a new, empty value return self.initialize(state) - def append(self, state, value, initiator): + def append(self, state, value, initiator, passive=False): self.set(state, value, initiator) - def remove(self, state, value, initiator): + def remove(self, state, value, initiator, passive=False): self.set(state, None, initiator) def set(self, state, value, initiator): @@ -433,17 +434,27 @@ class CollectionAttributeImpl(AttributeImpl): state.dict[self.key] = user_data return user_data - def append(self, state, value, initiator): + def append(self, state, value, initiator, passive=False): if initiator is self: return - collection = self.get_collection(state) - collection.append_with_event(value, initiator) - def remove(self, state, value, initiator): + collection = self.get_collection(state, passive=passive) + if collection is PASSIVE_NORESULT: + state.get_pending(self).append(value) + self.fire_append_event(state, value, initiator) + else: + collection.append_with_event(value, initiator) + + def remove(self, state, value, initiator, passive=False): if initiator is self: return - collection = self.get_collection(state) - collection.remove_with_event(value, initiator) + + collection = self.get_collection(state, passive=passive) + if collection is PASSIVE_NORESULT: + state.get_pending(self).remove(value) + self.fire_remove_event(state, value, initiator) + else: + collection.remove_with_event(value, initiator) def set(self, state, value, initiator): """Set a value on the given object. @@ -470,7 +481,7 @@ class CollectionAttributeImpl(AttributeImpl): old = self.get(state) old_collection = self.get_collection(state, old) - + new_collection, user_data = self._build_collection(state) idset = util.IdentitySet @@ -494,7 +505,10 @@ class CollectionAttributeImpl(AttributeImpl): old_collection.unlink(old) def set_committed_value(self, state, value): - """Set an attribute value on the given instance and 'commit' it.""" + """Set an attribute value on the given instance and 'commit' it. + + Loads the existing collection from lazy callables in all cases. + """ collection, user_data = self._build_collection(state) self._load_collection(state, value or [], emit_events=False, @@ -509,24 +523,45 @@ class CollectionAttributeImpl(AttributeImpl): return value def _build_collection(self, state): + """build a new, blank collection and return it wrapped in a CollectionAdapter.""" + user_data = self.collection_factory() collection = collections.CollectionAdapter(self, state, user_data) return collection, user_data def _load_collection(self, state, values, emit_events=True, collection=None): + """given an empty CollectionAdapter, load the collection with current values. + + Loads the collection from lazy callables in all cases. + """ + collection = collection or self.get_collection(state) if values is None: return - elif emit_events: + + appender = emit_events and collection.append_with_event or collection.append_without_event + + if self.key in state.pending: + # move 'pending' items into the newly loaded collection + added = state.pending[self.key].added_items + removed = state.pending[self.key].deleted_items for item in values: - collection.append_with_event(item) + if item not in removed: + appender(item) + for item in added: + appender(item) + del state.pending[self.key] else: for item in values: - collection.append_without_event(item) + appender(item) - def get_collection(self, state, user_data=None): + def get_collection(self, state, user_data=None, passive=False): + """retrieve the CollectionAdapter associated with the given state.""" + if user_data is None: - user_data = self.get(state) + user_data = self.get(state, passive=passive) + if user_data is PASSIVE_NORESULT: + return user_data try: return getattr(user_data, '_sa_adapter') except AttributeError: @@ -554,18 +589,18 @@ class GenericBackrefExtension(interfaces.AttributeExtension): # present when updating via a backref. impl = getattr(oldchild.__class__, self.key).impl try: - impl.remove(oldchild._state, instance, initiator) + impl.remove(oldchild._state, instance, initiator, passive=True) except (ValueError, KeyError, IndexError): pass if child is not None: - getattr(child.__class__, self.key).impl.append(child._state, instance, initiator) + getattr(child.__class__, self.key).impl.append(child._state, instance, initiator, passive=True) def append(self, instance, child, initiator): - getattr(child.__class__, self.key).impl.append(child._state, instance, initiator) + getattr(child.__class__, self.key).impl.append(child._state, instance, initiator, passive=True) def remove(self, instance, child, initiator): if child is not None: - getattr(child.__class__, self.key).impl.remove(child._state, instance, initiator) + getattr(child.__class__, self.key).impl.remove(child._state, instance, initiator, passive=True) class ClassState(object): """tracks state information at the class level.""" @@ -577,7 +612,7 @@ class ClassState(object): class InstanceState(object): """tracks state information at the instance level.""" - __slots__ = 'class_', 'obj', 'dict', 'committed_state', 'modified', 'trigger', 'callables', 'parents', 'instance_dict', '_strong_obj', 'expired_attributes' + __slots__ = 'class_', 'obj', 'dict', 'pending', 'committed_state', 'modified', 'trigger', 'callables', 'parents', 'instance_dict', '_strong_obj', 'expired_attributes' def __init__(self, obj): self.class_ = obj.__class__ @@ -588,6 +623,7 @@ class InstanceState(object): self.trigger = None self.callables = {} self.parents = {} + self.pending = {} self.instance_dict = None def __cleanup(self, ref): @@ -627,6 +663,11 @@ class InstanceState(object): finally: instance_dict._mutex.release() + def get_pending(self, attributeimpl): + if attributeimpl.key not in self.pending: + self.pending[attributeimpl.key] = PendingCollection() + return self.pending[attributeimpl.key] + def is_modified(self): if self.modified: return True @@ -654,11 +695,12 @@ class InstanceState(object): return None def __getstate__(self): - return {'committed_state':self.committed_state, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj()} + return {'committed_state':self.committed_state, 'pending':self.pending, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj()} def __setstate__(self, state): self.committed_state = state['committed_state'] self.parents = state['parents'] + self.pending = state['pending'] self.modified = state['modified'] self.obj = weakref.ref(state['instance']) self.class_ = self.obj().__class__ @@ -857,7 +899,7 @@ class AttributeHistory(object): particular instance. """ - def __init__(self, attr, state, current, passive=False): + def __init__(self, attr, state, current): self.attr = attr # get the "original" value. if a lazy load was fired when we got @@ -919,6 +961,27 @@ class AttributeHistory(object): def deleted_items(self): return list(self._deleted_items) +class PendingCollection(object): + """stores items appended and removed from a collection that has not been loaded yet. + + When the collection is loaded, the changes present in PendingCollection are applied + to produce the final result. + """ + + def __init__(self): + self.deleted_items = util.IdentitySet() + self.added_items = util.OrderedIdentitySet() + + def append(self, value): + if value in self.deleted_items: + self.deleted_items.remove(value) + self.added_items.add(value) + + def remove(self, value): + if value in self.added_items: + self.added_items.remove(value) + self.deleted_items.add(value) + def _managed_attributes(class_): """return all InstrumentedAttributes associated with the given class_ and its superclasses.""" diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index 56cf58d9b5..0c49bcfc39 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -44,12 +44,12 @@ class DynamicAttributeImpl(attributes.AttributeImpl): state.dict[self.key] = c = CollectionHistory(self, state) return c - def append(self, state, value, initiator): + def append(self, state, value, initiator, passive=False): if initiator is not self: self.get_history(state)._added_items.append(value) self.fire_append_event(state, value, initiator) - def remove(self, state, value, initiator): + def remove(self, state, value, initiator, passive=False): if initiator is not self: self.get_history(state)._deleted_items.append(value) self.fire_remove_event(state, value, initiator) diff --git a/test/orm/attributes.py b/test/orm/attributes.py index 4e41f0a295..b321dc50a1 100644 --- a/test/orm/attributes.py +++ b/test/orm/attributes.py @@ -150,82 +150,9 @@ class AttributesTest(PersistTest): self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') self.assert_(len(attributes.get_history(u, 'addresses').unchanged_items()) == 1) - def test_backref(self): - class Student(object):pass - class Course(object):pass - - attributes.register_class(Student) - attributes.register_class(Course) - attributes.register_attribute(Student, 'courses', uselist=True, extension=attributes.GenericBackrefExtension('students'), useobject=True) - attributes.register_attribute(Course, 'students', uselist=True, extension=attributes.GenericBackrefExtension('courses'), useobject=True) - - s = Student() - c = Course() - s.courses.append(c) - self.assert_(c.students == [s]) - s.courses.remove(c) - self.assert_(c.students == []) - - (s1, s2, s3) = (Student(), Student(), Student()) - - c.students = [s1, s2, s3] - self.assert_(s2.courses == [c]) - self.assert_(s1.courses == [c]) - print "--------------------------------" - print s1 - print s1.courses - print c - print c.students - s1.courses.remove(c) - self.assert_(c.students == [s2,s3]) - class Post(object):pass - class Blog(object):pass - - attributes.register_class(Post) - attributes.register_class(Blog) - attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True) - attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True) - b = Blog() - (p1, p2, p3) = (Post(), Post(), Post()) - b.posts.append(p1) - b.posts.append(p2) - b.posts.append(p3) - self.assert_(b.posts == [p1, p2, p3]) - self.assert_(p2.blog is b) - p3.blog = None - self.assert_(b.posts == [p1, p2]) - p4 = Post() - p4.blog = b - self.assert_(b.posts == [p1, p2, p4]) - - p4.blog = b - p4.blog = b - self.assert_(b.posts == [p1, p2, p4]) - - # assert no failure removing None - p5 = Post() - p5.blog = None - del p5.blog - - class Port(object):pass - class Jack(object):pass - attributes.register_class(Port) - attributes.register_class(Jack) - attributes.register_attribute(Port, 'jack', uselist=False, extension=attributes.GenericBackrefExtension('port'), useobject=True) - attributes.register_attribute(Jack, 'port', uselist=False, extension=attributes.GenericBackrefExtension('jack'), useobject=True) - p = Port() - j = Jack() - p.jack = j - self.assert_(j.port is p) - self.assert_(p.jack is not None) - - j.port = None - self.assert_(p.jack is None) - def test_lazytrackparent(self): """test that the "hasparent" flag works properly when lazy loaders and backrefs are used""" - class Post(object):pass class Blog(object):pass @@ -449,6 +376,173 @@ class AttributesTest(PersistTest): assert True except exceptions.ArgumentError, e: assert False - + + +class BackrefTest(PersistTest): + + def test_manytomany(self): + class Student(object):pass + class Course(object):pass + + attributes.register_class(Student) + attributes.register_class(Course) + attributes.register_attribute(Student, 'courses', uselist=True, extension=attributes.GenericBackrefExtension('students'), useobject=True) + attributes.register_attribute(Course, 'students', uselist=True, extension=attributes.GenericBackrefExtension('courses'), useobject=True) + + s = Student() + c = Course() + s.courses.append(c) + self.assert_(c.students == [s]) + s.courses.remove(c) + self.assert_(c.students == []) + + (s1, s2, s3) = (Student(), Student(), Student()) + + c.students = [s1, s2, s3] + self.assert_(s2.courses == [c]) + self.assert_(s1.courses == [c]) + print "--------------------------------" + print s1 + print s1.courses + print c + print c.students + s1.courses.remove(c) + self.assert_(c.students == [s2,s3]) + + def test_onetomany(self): + class Post(object):pass + class Blog(object):pass + + attributes.register_class(Post) + attributes.register_class(Blog) + attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True) + attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True) + b = Blog() + (p1, p2, p3) = (Post(), Post(), Post()) + b.posts.append(p1) + b.posts.append(p2) + b.posts.append(p3) + self.assert_(b.posts == [p1, p2, p3]) + self.assert_(p2.blog is b) + + p3.blog = None + self.assert_(b.posts == [p1, p2]) + p4 = Post() + p4.blog = b + self.assert_(b.posts == [p1, p2, p4]) + + p4.blog = b + p4.blog = b + self.assert_(b.posts == [p1, p2, p4]) + + # assert no failure removing None + p5 = Post() + p5.blog = None + del p5.blog + + def test_onetoone(self): + class Port(object):pass + class Jack(object):pass + attributes.register_class(Port) + attributes.register_class(Jack) + attributes.register_attribute(Port, 'jack', uselist=False, extension=attributes.GenericBackrefExtension('port'), useobject=True) + attributes.register_attribute(Jack, 'port', uselist=False, extension=attributes.GenericBackrefExtension('jack'), useobject=True) + p = Port() + j = Jack() + p.jack = j + self.assert_(j.port is p) + self.assert_(p.jack is not None) + + j.port = None + self.assert_(p.jack is None) + +class DeferredBackrefTest(PersistTest): + def setUp(self): + global Post, Blog, called, lazy_load + + class Post(object): + def __init__(self, name): + self.name = name + def __eq__(self, other): + return other.name == self.name + + class Blog(object): + def __init__(self, name): + self.name = name + def __eq__(self, other): + return other.name == self.name + + called = [0] + + lazy_load = [] + def lazy_posts(instance): + def load(): + called[0] += 1 + return lazy_load + return load + + attributes.register_class(Post) + attributes.register_class(Blog) + attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True) + attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), callable_=lazy_posts, trackparent=True, useobject=True) + + def test_lazy_add(self): + global lazy_load + + p1, p2, p3 = Post("post 1"), Post("post 2"), Post("post 3") + lazy_load = [p1, p2, p3] + + b = Blog("blog 1") + p = Post("post 4") + p.blog = b + p = Post("post 5") + p.blog = b + # setting blog doesnt call 'posts' callable + assert called[0] == 0 + + # calling backref calls the callable, populates extra posts + assert b.posts == [p1, p2, p3, Post("post 4"), Post("post 5")] + assert called[0] == 1 + + def test_lazy_remove(self): + global lazy_load + called[0] = 0 + lazy_load = [] + + b = Blog("blog 1") + p = Post("post 1") + p.blog = b + assert called[0] == 0 + + lazy_load = [p] + + p.blog = None + p2 = Post("post 2") + p2.blog = b + assert called[0] == 0 + assert b.posts == [p2] + assert called[0] == 1 + + def test_normal_load(self): + global lazy_load + lazy_load = (p1, p2, p3) = [Post("post 1"), Post("post 2"), Post("post 3")] + called[0] = 0 + + b = Blog("blog 1") + + # assign without using backref system + p2.__dict__['blog'] = b + + assert b.posts == [Post("post 1"), Post("post 2"), Post("post 3")] + assert called[0] == 1 + p2.blog = None + p4 = Post("post 4") + p4.blog = b + assert b.posts == [Post("post 1"), Post("post 3"), Post("post 4")] + assert called[0] == 1 + + called[0] = 0 + lazy_load = (p1, p2, p3) = [Post("post 1"), Post("post 2"), Post("post 3")] + if __name__ == "__main__": testbase.main() diff --git a/test/orm/lazy_relations.py b/test/orm/lazy_relations.py index 97eda30063..487eb77168 100644 --- a/test/orm/lazy_relations.py +++ b/test/orm/lazy_relations.py @@ -272,7 +272,41 @@ class LazyTest(FixtureTest): u1 = sess.query(User).get(7) assert a.user is u1 + + def test_backrefs_dont_lazyload(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user') + }) + mapper(Address, addresses) + sess = create_session() + ad = sess.query(Address).filter_by(id=1).one() + assert ad.user.id == 7 + def go(): + ad.user = None + assert ad.user is None + self.assert_sql_count(testbase.db, go, 0) + + u1 = sess.query(User).filter_by(id=7).one() + def go(): + assert ad not in u1.addresses + self.assert_sql_count(testbase.db, go, 1) + + sess.expire(u1, ['addresses']) + def go(): + assert ad in u1.addresses + self.assert_sql_count(testbase.db, go, 1) + sess.expire(u1, ['addresses']) + ad2 = Address() + def go(): + ad2.user = u1 + assert ad2.user is u1 + self.assert_sql_count(testbase.db, go, 0) + + def go(): + assert ad2 in u1.addresses + self.assert_sql_count(testbase.db, go, 1) + class M2OGetTest(FixtureTest): keep_mappers = False keep_data = True diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 36f0561567..df1b6bba1d 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -1322,15 +1322,17 @@ class RequirementsTest(AssertMixin): h1.h1s.append(H1()) s.flush() - + self.assertEquals(t1.count().scalar(), 4) + h6 = H6() h6.h1a = h1 h6.h1b = h1 h6 = H6() h6.h1a = h1 - h6.h1b = H1() - + h6.h1b = x = H1() + assert x in s + h6.h1b.h2s.append(H2()) s.flush() diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index 158813cd7f..11d7313775 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -33,8 +33,7 @@ class HistoryTest(ORMTest): u = User(_sa_session=s) a = Address(_sa_session=s) a.user = u - #print repr(a.__class__._attribute_manager.get_history(a, 'user').added_items()) - #print repr(u.addresses.added_items()) + self.assert_(u.addresses == [a]) s.commit() -- 2.47.2