From ed8742e6858f11d48c78fcbbad35e92834aa47f0 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 25 Jul 2009 17:08:38 +0000 Subject: [PATCH] - The collection proxies produced by associationproxy are now pickleable. A user-defined proxy_factory however is still not pickleable unless it defines __getstate__ and __setstate__. [ticket:1446] --- CHANGES | 5 + lib/sqlalchemy/ext/associationproxy.py | 167 ++++++++++--------------- test/ext/test_associationproxy.py | 100 ++++++++++++--- 3 files changed, 151 insertions(+), 121 deletions(-) diff --git a/CHANGES b/CHANGES index e4acdcfeec..600fec83d6 100644 --- a/CHANGES +++ b/CHANGES @@ -22,6 +22,11 @@ CHANGES the string "field" argument was getting treated as a ClauseElement, causing various errors within more complex SQL transformations. +- ext + - The collection proxies produced by associationproxy are now + pickleable. A user-defined proxy_factory however + is still not pickleable unless it defines __getstate__ + and __setstate__. [ticket:1446] 0.5.5 ======= diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 315142d8e0..e126fe638d 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -140,26 +140,14 @@ class AssociationProxy(object): return (orm.class_mapper(self.owning_class). get_property(self.target_collection)) + @property def target_class(self): """The class the proxy is attached to.""" return self._get_property().mapper.class_ - target_class = property(target_class) def _target_is_scalar(self): return not self._get_property().uselist - def _lazy_collection(self, weakobjref): - target = self.target_collection - del self - def lazy_collection(): - obj = weakobjref() - if obj is None: - raise exceptions.InvalidRequestError( - "stale association proxy, parent object has gone out of " - "scope") - return getattr(obj, target) - return lazy_collection - def __get__(self, obj, class_): if self.owning_class is None: self.owning_class = class_ and class_ or type(obj) @@ -181,10 +169,10 @@ class AssociationProxy(object): return proxy except AttributeError: pass - proxy = self._new(self._lazy_collection(weakref.ref(obj))) + proxy = self._new(_lazy_collection(obj, self.target_collection)) setattr(obj, self.key, (id(obj), proxy)) return proxy - + def __set__(self, obj, values): if self.owning_class is None: self.owning_class = type(obj) @@ -238,13 +226,13 @@ class AssociationProxy(object): getter, setter = self.getset_factory(self.collection_class, self) else: getter, setter = self._default_getset(self.collection_class) - + if self.collection_class is list: - return _AssociationList(lazy_collection, creator, getter, setter) + return _AssociationList(lazy_collection, creator, getter, setter, self) elif self.collection_class is dict: - return _AssociationDict(lazy_collection, creator, getter, setter) + return _AssociationDict(lazy_collection, creator, getter, setter, self) elif self.collection_class is set: - return _AssociationSet(lazy_collection, creator, getter, setter) + return _AssociationSet(lazy_collection, creator, getter, setter, self) else: raise exceptions.ArgumentError( 'could not guess which interface to use for ' @@ -252,6 +240,18 @@ class AssociationProxy(object): 'proxy_factory and proxy_bulk_set manually' % (self.collection_class.__name__, self.target_collection)) + def _inflate(self, proxy): + creator = self.creator and self.creator or self.target_class + + if self.getset_factory: + getter, setter = self.getset_factory(self.collection_class, self) + else: + getter, setter = self._default_getset(self.collection_class) + + proxy.creator = creator + proxy.getter = getter + proxy.setter = setter + def _set(self, proxy, values): if self.proxy_bulk_set: self.proxy_bulk_set(proxy, values) @@ -266,12 +266,32 @@ class AssociationProxy(object): 'no proxy_bulk_set supplied for custom ' 'collection_class implementation') +class _lazy_collection(object): + def __init__(self, obj, target): + self.ref = weakref.ref(obj) + self.target = target -class _AssociationList(object): - """Generic, converting, list-to-list proxy.""" - - def __init__(self, lazy_collection, creator, getter, setter): - """Constructs an _AssociationList. + def __call__(self): + obj = self.ref() + if obj is None: + raise exceptions.InvalidRequestError( + "stale association proxy, parent object has gone out of " + "scope") + return getattr(obj, self.target) + + def __getstate__(self): + return {'obj':self.ref(), 'target':self.target} + + def __setstate__(self, state): + self.ref = weakref.ref(state['obj']) + self.target = state['target'] + +class _AssociationCollection(object): + def __init__(self, lazy_collection, creator, getter, setter, parent): + """Constructs an _AssociationCollection. + + This will always be a subclass of either _AssociationList, + _AssociationSet, or _AssociationDict. lazy_collection A callable returning a list-based collection of entities (usually an @@ -296,9 +316,27 @@ class _AssociationList(object): self.creator = creator self.getter = getter self.setter = setter + self.parent = parent col = property(lambda self: self.lazy_collection()) + def __len__(self): + return len(self.col) + + def __nonzero__(self): + return bool(self.col) + + def __getstate__(self): + return {'parent':self.parent, 'lazy_collection':self.lazy_collection} + + def __setstate__(self, state): + self.parent = state['parent'] + self.lazy_collection = state['lazy_collection'] + self.parent._inflate(self) + +class _AssociationList(_AssociationCollection): + """Generic, converting, list-to-list proxy.""" + def _create(self, value): return self.creator(value) @@ -308,15 +346,6 @@ class _AssociationList(object): def _set(self, object, value): return self.setter(object, value) - def __len__(self): - return len(self.col) - - def __nonzero__(self): - if self.col: - return True - else: - return False - def __getitem__(self, index): return self._get(self.col[index]) @@ -494,39 +523,9 @@ class _AssociationList(object): _NotProvided = util.symbol('_NotProvided') -class _AssociationDict(object): +class _AssociationDict(_AssociationCollection): """Generic, converting, dict-to-dict proxy.""" - def __init__(self, lazy_collection, creator, getter, setter): - """Constructs an _AssociationDict. - - lazy_collection - A callable returning a dict-based collection of entities (usually an - object attribute managed by a SQLAlchemy relation()) - - creator - A function that creates new target entities. Given two parameters: - key and value. The assertion is assumed:: - - obj = creator(somekey, somevalue) - assert getter(somekey) == somevalue - - getter - A function. Given an associated object and a key, return the - 'value'. - - setter - A function. Given an associated object, a key and a value, store - that value on the object. - - """ - self.lazy_collection = lazy_collection - self.creator = creator - self.getter = getter - self.setter = setter - - col = property(lambda self: self.lazy_collection()) - def _create(self, key, value): return self.creator(key, value) @@ -536,15 +535,6 @@ class _AssociationDict(object): def _set(self, object, key, value): return self.setter(object, key, value) - def __len__(self): - return len(self.col) - - def __nonzero__(self): - if self.col: - return True - else: - return False - def __getitem__(self, key): return self._get(self.col[key]) @@ -669,38 +659,9 @@ class _AssociationDict(object): del func_name, func -class _AssociationSet(object): +class _AssociationSet(_AssociationCollection): """Generic, converting, set-to-set proxy.""" - def __init__(self, lazy_collection, creator, getter, setter): - """Constructs an _AssociationSet. - - collection - A callable returning a set-based collection of entities (usually an - object attribute managed by a SQLAlchemy relation()) - - creator - A function that creates new target entities. Given one parameter: - value. The assertion is assumed:: - - obj = creator(somevalue) - assert getter(obj) == somevalue - - getter - A function. Given an associated object, return the 'value'. - - setter - A function. Given an associated object and a value, store that - value on the object. - - """ - self.lazy_collection = lazy_collection - self.creator = creator - self.getter = getter - self.setter = setter - - col = property(lambda self: self.lazy_collection()) - def _create(self, value): return self.creator(value) diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py index 742f98baf8..8df449718e 100644 --- a/test/ext/test_associationproxy.py +++ b/test/ext/test_associationproxy.py @@ -1,5 +1,8 @@ -from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +from sqlalchemy.test.testing import eq_, assert_raises +import copy import gc +import pickle + from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.orm.collections import collection @@ -15,12 +18,15 @@ class DictCollection(dict): def remove(self, obj): del self[obj.foo] + class SetCollection(set): pass + class ListCollection(list): pass + class ObjectCollection(object): def __init__(self): self.values = list() @@ -33,6 +39,21 @@ class ObjectCollection(object): def __iter__(self): return iter(self.values) + +class Parent(object): + kids = association_proxy('children', 'name') + def __init__(self, name): + self.name = name + +class Child(object): + def __init__(self, name): + self.name = name + +class KVChild(object): + def __init__(self, name, value): + self.name = name + self.value = value + class _CollectionOperations(TestBase): def setup(self): collection_class = self.collection_class @@ -837,29 +858,23 @@ class ReconstitutionTest(TestBase): metadata.create_all() parents.insert().execute(name='p1') - class Parent(object): - kids = association_proxy('children', 'name') - def __init__(self, name): - self.name = name - - class Child(object): - def __init__(self, name): - self.name = name - - mapper(Parent, parents, properties=dict(children=relation(Child))) - mapper(Child, children) self.metadata = metadata - self.Parent = Parent - + self.parents = parents + self.children = children + def teardown(self): self.metadata.drop_all() + clear_mappers() def test_weak_identity_map(self): + mapper(Parent, self.parents, properties=dict(children=relation(Child))) + mapper(Child, self.children) + session = create_session(weak_identity_map=True) def add_child(parent_name, child_name): - parent = (session.query(self.Parent). + parent = (session.query(Parent). filter_by(name=parent_name)).one() parent.kids.append(child_name) @@ -869,12 +884,14 @@ class ReconstitutionTest(TestBase): add_child('p1', 'c2') session.flush() - p = session.query(self.Parent).filter_by(name='p1').one() + p = session.query(Parent).filter_by(name='p1').one() assert set(p.kids) == set(['c1', 'c2']), p.kids def test_copy(self): - import copy - p = self.Parent('p1') + mapper(Parent, self.parents, properties=dict(children=relation(Child))) + mapper(Child, self.children) + + p = Parent('p1') p.kids.extend(['c1', 'c2']) p_copy = copy.copy(p) del p @@ -882,4 +899,51 @@ class ReconstitutionTest(TestBase): assert set(p_copy.kids) == set(['c1', 'c2']), p.kids + def test_pickle_list(self): + mapper(Parent, self.parents, properties=dict(children=relation(Child))) + mapper(Child, self.children) + + p = Parent('p1') + p.kids.extend(['c1', 'c2']) + + r1 = pickle.loads(pickle.dumps(p)) + assert r1.kids == ['c1', 'c2'] + r2 = pickle.loads(pickle.dumps(p.kids)) + assert r2 == ['c1', 'c2'] + + def test_pickle_set(self): + mapper(Parent, self.parents, properties=dict(children=relation(Child, collection_class=set))) + mapper(Child, self.children) + + p = Parent('p1') + p.kids.update(['c1', 'c2']) + + r1 = pickle.loads(pickle.dumps(p)) + assert r1.kids == set(['c1', 'c2']) + + r2 = pickle.loads(pickle.dumps(p.kids)) + assert r2 == set(['c1', 'c2']) + + def test_pickle_dict(self): + mapper(Parent, self.parents, properties=dict( + children=relation(KVChild, collection_class=collections.mapped_collection(PickleKeyFunc('name'))) + )) + mapper(KVChild, self.children) + + p = Parent('p1') + p.kids.update({'c1':'v1', 'c2':'v2'}) + assert p.kids == {'c1':'c1', 'c2':'c2'} + + r1 = pickle.loads(pickle.dumps(p)) + assert r1.kids == {'c1':'c1', 'c2':'c2'} + + r2 = pickle.loads(pickle.dumps(p.kids)) + assert r2 == {'c1':'c1', 'c2':'c2'} + +class PickleKeyFunc(object): + def __init__(self, name): + self.name = name + + def __call__(self, obj): + return getattr(obj, self.name) \ No newline at end of file -- 2.47.2