From f83c9a3959e25e5817bae6f0ca0015f9054baf8d Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 12 Sep 2009 19:59:39 +0000 Subject: [PATCH] - Added an assertion that prevents a @validates function or other AttributeExtension from loading an unloaded collection such that internal state may be corrupted. [ticket:1526] --- CHANGES | 7 +++++- lib/sqlalchemy/orm/attributes.py | 2 ++ test/orm/test_attributes.py | 43 +++++++++++++++++++++++++++++++- 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/CHANGES b/CHANGES index 68dd7cbe02..d3a5fb2e17 100644 --- a/CHANGES +++ b/CHANGES @@ -390,7 +390,12 @@ CHANGES - Fixed bug which disallowed one side of a many-to-many bidirectional reference to declare itself as "viewonly" [ticket:1507] - + + - Added an assertion that prevents a @validates function + or other AttributeExtension from loading an unloaded + collection such that internal state may be corrupted. + [ticket:1526] + - Fixed bug which prevented two entities from mutually replacing each other's primary key values within a single flush() for some orderings of operations. [ticket:1519] diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 3ca7d83119..6fa8d54c4e 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -700,6 +700,7 @@ class CollectionAttributeImpl(AttributeImpl): collection = self.get_collection(state, dict_, passive=passive) if collection is PASSIVE_NO_RESULT: value = self.fire_append_event(state, dict_, value, initiator) + assert self.key not in dict_, "Collection was loaded during event handling." state.get_pending(self.key).append(value) else: collection.append_with_event(value, initiator) @@ -711,6 +712,7 @@ class CollectionAttributeImpl(AttributeImpl): collection = self.get_collection(state, state.dict, passive=passive) if collection is PASSIVE_NO_RESULT: self.fire_remove_event(state, dict_, value, initiator) + assert self.key not in dict_, "Collection was loaded during event handling." state.get_pending(self.key).remove(value) else: collection.remove_with_event(value, initiator) diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py index b481b06791..99f0a49d39 100644 --- a/test/orm/test_attributes.py +++ b/test/orm/test_attributes.py @@ -4,7 +4,7 @@ from sqlalchemy.orm.collections import collection from sqlalchemy.orm.interfaces import AttributeExtension from sqlalchemy import exc as sa_exc from sqlalchemy.test import * -from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.testing import eq_, assert_raises from test.orm import _base from sqlalchemy.test.util import gc_collect from sqlalchemy.util import cmp, jython @@ -220,7 +220,48 @@ class AttributesTest(_base.ORMTest): a.email_address = 'foo@bar.com' u.addresses.append(a) self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.addresses[0].email_address == 'lala@123.com' and u.addresses[1].email_address == 'foo@bar.com') + + def test_extension_lazyload_assertion(self): + class Foo(_base.BasicEntity): + pass + class Bar(_base.BasicEntity): + pass + + class ReceiveEvents(AttributeExtension): + def append(self, state, child, initiator): + state.obj().bars + return child + + def remove(self, state, child, initiator): + state.obj().bars + return child + + def set(self, state, child, oldchild, initiator): + return child + attributes.register_class(Foo) + attributes.register_class(Bar) + + bar1, bar2, bar3 = [Bar(id=1), Bar(id=2), Bar(id=3)] + def func1(**kw): + if kw.get('passive') is attributes.PASSIVE_NO_FETCH: + return attributes.PASSIVE_NO_RESULT + + return [bar1, bar2, bar3] + + attributes.register_attribute(Foo, 'bars', uselist=True, callable_=lambda o:func1, useobject=True, extension=[ReceiveEvents()]) + attributes.register_attribute(Bar, 'foos', uselist=True, useobject=True, extension=[attributes.GenericBackrefExtension('bars')]) + + x = Foo() + assert_raises(AssertionError, Bar(id=4).foos.append, x) + + x.bars + b = Bar(id=4) + b.foos.append(x) + attributes.instance_state(x).expire_attributes(['bars']) + assert_raises(AssertionError, b.foos.remove, x) + + def test_scalar_listener(self): # listeners on ScalarAttributeImpl and MutableScalarAttributeImpl aren't used normally. # test that they work for the benefit of user extensions -- 2.47.2