]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- The collection proxies produced by associationproxy are now
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Jul 2009 17:08:38 +0000 (17:08 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Jul 2009 17:08:38 +0000 (17:08 +0000)
pickleable.  A user-defined proxy_factory however
is still not pickleable unless it defines __getstate__
and __setstate__. [ticket:1446]

CHANGES
lib/sqlalchemy/ext/associationproxy.py
test/ext/test_associationproxy.py

diff --git a/CHANGES b/CHANGES
index e4acdcfeecf5d18c29ac8ce5041cb0f67e356f21..600fec83d6c63806b60808b1dd559025360ace52 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -22,6 +22,11 @@ CHANGES
       the string "field" argument was getting treated as a 
       ClauseElement, causing various errors within more 
       complex SQL transformations.
+- ext
+   - The collection proxies produced by associationproxy are now
+     pickleable.  A user-defined proxy_factory however
+     is still not pickleable unless it defines __getstate__
+     and __setstate__. [ticket:1446]
       
 0.5.5
 =======
index 315142d8e0427119f052cc27b8f01187e6a01ed3..e126fe638d772bfad177a12b303caae9e6876ca6 100644 (file)
@@ -140,26 +140,14 @@ class AssociationProxy(object):
         return (orm.class_mapper(self.owning_class).
                 get_property(self.target_collection))
 
+    @property
     def target_class(self):
         """The class the proxy is attached to."""
         return self._get_property().mapper.class_
-    target_class = property(target_class)
 
     def _target_is_scalar(self):
         return not self._get_property().uselist
 
-    def _lazy_collection(self, weakobjref):
-        target = self.target_collection
-        del self
-        def lazy_collection():
-            obj = weakobjref()
-            if obj is None:
-                raise exceptions.InvalidRequestError(
-                    "stale association proxy, parent object has gone out of "
-                    "scope")
-            return getattr(obj, target)
-        return lazy_collection
-
     def __get__(self, obj, class_):
         if self.owning_class is None:
             self.owning_class = class_ and class_ or type(obj)
@@ -181,10 +169,10 @@ class AssociationProxy(object):
                     return proxy
             except AttributeError:
                 pass
-            proxy = self._new(self._lazy_collection(weakref.ref(obj)))
+            proxy = self._new(_lazy_collection(obj, self.target_collection))
             setattr(obj, self.key, (id(obj), proxy))
             return proxy
-
+    
     def __set__(self, obj, values):
         if self.owning_class is None:
             self.owning_class = type(obj)
@@ -238,13 +226,13 @@ class AssociationProxy(object):
             getter, setter = self.getset_factory(self.collection_class, self)
         else:
             getter, setter = self._default_getset(self.collection_class)
-
+        
         if self.collection_class is list:
-            return _AssociationList(lazy_collection, creator, getter, setter)
+            return _AssociationList(lazy_collection, creator, getter, setter, self)
         elif self.collection_class is dict:
-            return _AssociationDict(lazy_collection, creator, getter, setter)
+            return _AssociationDict(lazy_collection, creator, getter, setter, self)
         elif self.collection_class is set:
-            return _AssociationSet(lazy_collection, creator, getter, setter)
+            return _AssociationSet(lazy_collection, creator, getter, setter, self)
         else:
             raise exceptions.ArgumentError(
                 'could not guess which interface to use for '
@@ -252,6 +240,18 @@ class AssociationProxy(object):
                 'proxy_factory and proxy_bulk_set manually' %
                 (self.collection_class.__name__, self.target_collection))
 
+    def _inflate(self, proxy):
+        creator = self.creator and self.creator or self.target_class
+
+        if self.getset_factory:
+            getter, setter = self.getset_factory(self.collection_class, self)
+        else:
+            getter, setter = self._default_getset(self.collection_class)
+        
+        proxy.creator = creator
+        proxy.getter = getter
+        proxy.setter = setter
+
     def _set(self, proxy, values):
         if self.proxy_bulk_set:
             self.proxy_bulk_set(proxy, values)
@@ -266,12 +266,32 @@ class AssociationProxy(object):
                'no proxy_bulk_set supplied for custom '
                'collection_class implementation')
 
+class _lazy_collection(object):
+    def __init__(self, obj, target):
+        self.ref = weakref.ref(obj)
+        self.target = target
 
-class _AssociationList(object):
-    """Generic, converting, list-to-list proxy."""
-
-    def __init__(self, lazy_collection, creator, getter, setter):
-        """Constructs an _AssociationList.
+    def __call__(self):
+        obj = self.ref()
+        if obj is None:
+            raise exceptions.InvalidRequestError(
+               "stale association proxy, parent object has gone out of "
+               "scope")
+        return getattr(obj, self.target)
+
+    def __getstate__(self):
+        return {'obj':self.ref(), 'target':self.target}
+    
+    def __setstate__(self, state):
+        self.ref = weakref.ref(state['obj'])
+        self.target = state['target']
+
+class _AssociationCollection(object):
+    def __init__(self, lazy_collection, creator, getter, setter, parent):
+        """Constructs an _AssociationCollection.  
+        
+        This will always be a subclass of either _AssociationList,
+        _AssociationSet, or _AssociationDict.
 
         lazy_collection
           A callable returning a list-based collection of entities (usually an
@@ -296,9 +316,27 @@ class _AssociationList(object):
         self.creator = creator
         self.getter = getter
         self.setter = setter
+        self.parent = parent
 
     col = property(lambda self: self.lazy_collection())
 
+    def __len__(self):
+        return len(self.col)
+
+    def __nonzero__(self):
+        return bool(self.col)
+
+    def __getstate__(self):
+        return {'parent':self.parent, 'lazy_collection':self.lazy_collection}
+
+    def __setstate__(self, state):
+        self.parent = state['parent']
+        self.lazy_collection = state['lazy_collection']
+        self.parent._inflate(self)
+    
+class _AssociationList(_AssociationCollection):
+    """Generic, converting, list-to-list proxy."""
+
     def _create(self, value):
         return self.creator(value)
 
@@ -308,15 +346,6 @@ class _AssociationList(object):
     def _set(self, object, value):
         return self.setter(object, value)
 
-    def __len__(self):
-        return len(self.col)
-
-    def __nonzero__(self):
-        if self.col:
-            return True
-        else:
-            return False
-
     def __getitem__(self, index):
         return self._get(self.col[index])
 
@@ -494,39 +523,9 @@ class _AssociationList(object):
 
 
 _NotProvided = util.symbol('_NotProvided')
-class _AssociationDict(object):
+class _AssociationDict(_AssociationCollection):
     """Generic, converting, dict-to-dict proxy."""
 
-    def __init__(self, lazy_collection, creator, getter, setter):
-        """Constructs an _AssociationDict.
-
-        lazy_collection
-          A callable returning a dict-based collection of entities (usually an
-          object attribute managed by a SQLAlchemy relation())
-
-        creator
-          A function that creates new target entities.  Given two parameters:
-          key and value.  The assertion is assumed::
-
-            obj = creator(somekey, somevalue)
-            assert getter(somekey) == somevalue
-
-        getter
-          A function.  Given an associated object and a key, return the
-          'value'.
-
-        setter
-          A function.  Given an associated object, a key and a value, store
-          that value on the object.
-
-        """
-        self.lazy_collection = lazy_collection
-        self.creator = creator
-        self.getter = getter
-        self.setter = setter
-
-    col = property(lambda self: self.lazy_collection())
-
     def _create(self, key, value):
         return self.creator(key, value)
 
@@ -536,15 +535,6 @@ class _AssociationDict(object):
     def _set(self, object, key, value):
         return self.setter(object, key, value)
 
-    def __len__(self):
-        return len(self.col)
-
-    def __nonzero__(self):
-        if self.col:
-            return True
-        else:
-            return False
-
     def __getitem__(self, key):
         return self._get(self.col[key])
 
@@ -669,38 +659,9 @@ class _AssociationDict(object):
     del func_name, func
 
 
-class _AssociationSet(object):
+class _AssociationSet(_AssociationCollection):
     """Generic, converting, set-to-set proxy."""
 
-    def __init__(self, lazy_collection, creator, getter, setter):
-        """Constructs an _AssociationSet.
-
-        collection
-          A callable returning a set-based collection of entities (usually an
-          object attribute managed by a SQLAlchemy relation())
-
-        creator
-          A function that creates new target entities.  Given one parameter:
-          value.  The assertion is assumed::
-
-            obj = creator(somevalue)
-            assert getter(obj) == somevalue
-
-        getter
-          A function.  Given an associated object, return the 'value'.
-
-        setter
-          A function.  Given an associated object and a value, store that
-          value on the object.
-
-        """
-        self.lazy_collection = lazy_collection
-        self.creator = creator
-        self.getter = getter
-        self.setter = setter
-
-    col = property(lambda self: self.lazy_collection())
-
     def _create(self, value):
         return self.creator(value)
 
index 742f98baf870431f4b504a5b9d006f837cbfdb5b..8df449718e22369ef198672c3874186e03f9a9a8 100644 (file)
@@ -1,5 +1,8 @@
-from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
+from sqlalchemy.test.testing import eq_, assert_raises
+import copy
 import gc
+import pickle
+
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy.orm.collections import collection
@@ -15,12 +18,15 @@ class DictCollection(dict):
     def remove(self, obj):
         del self[obj.foo]
 
+
 class SetCollection(set):
     pass
 
+
 class ListCollection(list):
     pass
 
+
 class ObjectCollection(object):
     def __init__(self):
         self.values = list()
@@ -33,6 +39,21 @@ class ObjectCollection(object):
     def __iter__(self):
         return iter(self.values)
 
+
+class Parent(object):
+    kids = association_proxy('children', 'name')
+    def __init__(self, name):
+        self.name = name
+
+class Child(object):
+    def __init__(self, name):
+        self.name = name
+
+class KVChild(object):
+    def __init__(self, name, value):
+        self.name = name
+        self.value = value
+
 class _CollectionOperations(TestBase):
     def setup(self):
         collection_class = self.collection_class
@@ -837,29 +858,23 @@ class ReconstitutionTest(TestBase):
         metadata.create_all()
         parents.insert().execute(name='p1')
 
-        class Parent(object):
-            kids = association_proxy('children', 'name')
-            def __init__(self, name):
-                self.name = name
-
-        class Child(object):
-            def __init__(self, name):
-                self.name = name
-
-        mapper(Parent, parents, properties=dict(children=relation(Child)))
-        mapper(Child, children)
 
         self.metadata = metadata
-        self.Parent = Parent
-
+        self.parents = parents
+        self.children = children
+        
     def teardown(self):
         self.metadata.drop_all()
+        clear_mappers()
 
     def test_weak_identity_map(self):
+        mapper(Parent, self.parents, properties=dict(children=relation(Child)))
+        mapper(Child, self.children)
+
         session = create_session(weak_identity_map=True)
 
         def add_child(parent_name, child_name):
-            parent = (session.query(self.Parent).
+            parent = (session.query(Parent).
                       filter_by(name=parent_name)).one()
             parent.kids.append(child_name)
 
@@ -869,12 +884,14 @@ class ReconstitutionTest(TestBase):
         add_child('p1', 'c2')
 
         session.flush()
-        p = session.query(self.Parent).filter_by(name='p1').one()
+        p = session.query(Parent).filter_by(name='p1').one()
         assert set(p.kids) == set(['c1', 'c2']), p.kids
 
     def test_copy(self):
-        import copy
-        p = self.Parent('p1')
+        mapper(Parent, self.parents, properties=dict(children=relation(Child)))
+        mapper(Child, self.children)
+
+        p = Parent('p1')
         p.kids.extend(['c1', 'c2'])
         p_copy = copy.copy(p)
         del p
@@ -882,4 +899,51 @@ class ReconstitutionTest(TestBase):
 
         assert set(p_copy.kids) == set(['c1', 'c2']), p.kids
 
+    def test_pickle_list(self):
+        mapper(Parent, self.parents, properties=dict(children=relation(Child)))
+        mapper(Child, self.children)
+
+        p = Parent('p1')
+        p.kids.extend(['c1', 'c2'])
+
+        r1 = pickle.loads(pickle.dumps(p))
+        assert r1.kids == ['c1', 'c2']
 
+        r2 = pickle.loads(pickle.dumps(p.kids))
+        assert r2 == ['c1', 'c2']
+
+    def test_pickle_set(self):
+        mapper(Parent, self.parents, properties=dict(children=relation(Child, collection_class=set)))
+        mapper(Child, self.children)
+
+        p = Parent('p1')
+        p.kids.update(['c1', 'c2'])
+
+        r1 = pickle.loads(pickle.dumps(p))
+        assert r1.kids == set(['c1', 'c2'])
+
+        r2 = pickle.loads(pickle.dumps(p.kids))
+        assert r2 == set(['c1', 'c2'])
+
+    def test_pickle_dict(self):
+        mapper(Parent, self.parents, properties=dict(
+                    children=relation(KVChild, collection_class=collections.mapped_collection(PickleKeyFunc('name')))
+                ))
+        mapper(KVChild, self.children)
+
+        p = Parent('p1')
+        p.kids.update({'c1':'v1', 'c2':'v2'})
+        assert p.kids == {'c1':'c1', 'c2':'c2'}
+        
+        r1 = pickle.loads(pickle.dumps(p))
+        assert r1.kids == {'c1':'c1', 'c2':'c2'}
+
+        r2 = pickle.loads(pickle.dumps(p.kids))
+        assert r2 == {'c1':'c1', 'c2':'c2'}
+
+class PickleKeyFunc(object):
+    def __init__(self, name):
+        self.name = name
+    
+    def __call__(self, obj):
+        return getattr(obj, self.name)
\ No newline at end of file