]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- [bug] Fixed bug whereby objects using
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Mar 2012 15:24:15 +0000 (10:24 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Mar 2012 15:24:15 +0000 (10:24 -0500)
attribute_mapped_collection or
column_mapped_collection could not be
pickled.  [ticket:2409]

CHANGES
lib/sqlalchemy/orm/collections.py
test/orm/test_pickled.py

diff --git a/CHANGES b/CHANGES
index ca97051d60c1f0ff9a672161ff20e46569b413e2..74c6df65328fc1a58288ca1e566727a1189b7e90 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -14,6 +14,11 @@ CHANGES
     invokes common table expression support
     from the Core (see below). [ticket:1859]
 
+  - [bug] Fixed bug whereby objects using
+    attribute_mapped_collection or 
+    column_mapped_collection could not be
+    pickled.  [ticket:2409]
+
   - [bug] Fixed bug whereby MappedCollection
     would not get the appropriate collection
     instrumentation if it were only used
index 2eebfbca298530a30f65eef68bc5141e6145c92b..160fac8be0c3e1b9df6e1574ed0d44869f80d84a 100644 (file)
@@ -112,12 +112,32 @@ from sqlalchemy.sql import expression
 from sqlalchemy import schema, util, exc as sa_exc
 
 
+
 __all__ = ['collection', 'collection_adapter',
            'mapped_collection', 'column_mapped_collection',
            'attribute_mapped_collection']
 
 __instrumentation_mutex = util.threading.Lock()
 
+class _SerializableColumnGetter(object):
+    def __init__(self, colkeys):
+        self.colkeys = colkeys
+        self.composite = len(colkeys) > 1
+
+    def __reduce__(self):
+        return _SerializableColumnGetter, (self.colkeys,)
+
+    def __call__(self, value):
+        state = instance_state(value)
+        m = _state_mapper(state)
+        key = [m._get_state_attr_by_column(
+                        state, state.dict, 
+                        m.mapped_table.columns[k])
+                     for k in self.colkeys]
+        if self.composite:
+            return tuple(key)
+        else:
+            return key[0]
 
 def column_mapped_collection(mapping_spec):
     """A dictionary-based collection type with column-based keying.
@@ -131,25 +151,27 @@ def column_mapped_collection(mapping_spec):
     after a session flush.
 
     """
+    global _state_mapper, instance_state
     from sqlalchemy.orm.util import _state_mapper
     from sqlalchemy.orm.attributes import instance_state
 
-    cols = [expression._only_column_elements(q, "mapping_spec") 
-                for q in util.to_list(mapping_spec)]
-    if len(cols) == 1:
-        def keyfunc(value):
-            state = instance_state(value)
-            m = _state_mapper(state)
-            return m._get_state_attr_by_column(state, state.dict, cols[0])
-    else:
-        mapping_spec = tuple(cols)
-        def keyfunc(value):
-            state = instance_state(value)
-            m = _state_mapper(state)
-            return tuple(m._get_state_attr_by_column(state, state.dict, c)
-                         for c in mapping_spec)
+    cols = [c.key for c in [
+                expression._only_column_elements(q, "mapping_spec") 
+                for q in util.to_list(mapping_spec)]]
+    keyfunc = _SerializableColumnGetter(cols)
     return lambda: MappedCollection(keyfunc)
 
+class _SerializableAttrGetter(object):
+    def __init__(self, name):
+        self.name = name
+        self.getter = operator.attrgetter(name)
+
+    def __call__(self, target):
+        return self.getter(target)
+
+    def __reduce__(self):
+        return _SerializableAttrGetter, (self.name, )
+
 def attribute_mapped_collection(attr_name):
     """A dictionary-based collection type with attribute-based keying.
 
@@ -163,7 +185,8 @@ def attribute_mapped_collection(attr_name):
     after a session flush.
 
     """
-    return lambda: MappedCollection(operator.attrgetter(attr_name))
+    getter = _SerializableAttrGetter(attr_name)
+    return lambda: MappedCollection(getter)
 
 
 def mapped_collection(keyfunc):
index aa560a2e09f32e508c560c81ee34021bddc012ac..f2d292832e3658a196a9c901bbf5090cbadb18e3 100644 (file)
@@ -11,6 +11,8 @@ from sqlalchemy.orm import mapper, relationship, create_session, \
                             clear_mappers, exc as orm_exc,\
                             configure_mappers, Session, lazyload_all,\
                             lazyload, aliased
+from sqlalchemy.orm.collections import attribute_mapped_collection, \
+    column_mapped_collection
 from test.lib import fixtures
 from test.orm import _fixtures
 from test.lib.pickleable import User, Address, Dingaling, Order, \
@@ -345,6 +347,72 @@ class PickleTest(fixtures.MappedTest):
                 repickled = loads(dumps(sa_exc))
                 eq_(repickled.args[0], sa_exc.args[0])
 
+    def test_attribute_mapped_collection(self):
+        users, addresses = self.tables.users, self.tables.addresses
+
+        mapper(User, users, properties={
+            'addresses':relationship(
+                            Address, 
+                            collection_class=
+                            attribute_mapped_collection('email_address')
+                        )
+        })
+        mapper(Address, addresses)
+        u1 = User()
+        u1.addresses = {"email1":Address(email_address="email1")}
+        for loads, dumps in picklers():
+            repickled = loads(dumps(u1))
+            eq_(u1.addresses, repickled.addresses)
+            eq_(repickled.addresses['email1'], 
+                    Address(email_address="email1"))
+
+    def test_column_mapped_collection(self):
+        users, addresses = self.tables.users, self.tables.addresses
+
+        mapper(User, users, properties={
+            'addresses':relationship(
+                            Address, 
+                            collection_class=
+                            column_mapped_collection(
+                                addresses.c.email_address)
+                        )
+        })
+        mapper(Address, addresses)
+        u1 = User()
+        u1.addresses = {
+            "email1":Address(email_address="email1"),
+            "email2":Address(email_address="email2")
+        }
+        for loads, dumps in picklers():
+            repickled = loads(dumps(u1))
+            eq_(u1.addresses, repickled.addresses)
+            eq_(repickled.addresses['email1'], 
+                    Address(email_address="email1"))
+
+    def test_composite_column_mapped_collection(self):
+        users, addresses = self.tables.users, self.tables.addresses
+
+        mapper(User, users, properties={
+            'addresses':relationship(
+                            Address, 
+                            collection_class=
+                            column_mapped_collection([
+                                addresses.c.id,
+                                addresses.c.email_address])
+                        )
+        })
+        mapper(Address, addresses)
+        u1 = User()
+        u1.addresses = {
+            (1, "email1"):Address(id=1, email_address="email1"),
+            (2, "email2"):Address(id=2, email_address="email2")
+        }
+        for loads, dumps in picklers():
+            repickled = loads(dumps(u1))
+            eq_(u1.addresses, repickled.addresses)
+            eq_(repickled.addresses[(1, 'email1')], 
+                    Address(id=1, email_address="email1"))
+
 class PolymorphicDeferredTest(fixtures.MappedTest):
     @classmethod
     def define_tables(cls, metadata):