From 8007834c97e938912dfd54b342d10e4b9c0a6095 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 5 Mar 2012 10:24:15 -0500 Subject: [PATCH] - [bug] Fixed bug whereby objects using attribute_mapped_collection or column_mapped_collection could not be pickled. [ticket:2409] --- CHANGES | 5 +++ lib/sqlalchemy/orm/collections.py | 53 +++++++++++++++++------- test/orm/test_pickled.py | 68 +++++++++++++++++++++++++++++++ 3 files changed, 111 insertions(+), 15 deletions(-) diff --git a/CHANGES b/CHANGES index ca97051d60..74c6df6532 100644 --- a/CHANGES +++ b/CHANGES @@ -14,6 +14,11 @@ CHANGES invokes common table expression support from the Core (see below). [ticket:1859] + - [bug] Fixed bug whereby objects using + attribute_mapped_collection or + column_mapped_collection could not be + pickled. [ticket:2409] + - [bug] Fixed bug whereby MappedCollection would not get the appropriate collection instrumentation if it were only used diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 2eebfbca29..160fac8be0 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -112,12 +112,32 @@ from sqlalchemy.sql import expression from sqlalchemy import schema, util, exc as sa_exc + __all__ = ['collection', 'collection_adapter', 'mapped_collection', 'column_mapped_collection', 'attribute_mapped_collection'] __instrumentation_mutex = util.threading.Lock() +class _SerializableColumnGetter(object): + def __init__(self, colkeys): + self.colkeys = colkeys + self.composite = len(colkeys) > 1 + + def __reduce__(self): + return _SerializableColumnGetter, (self.colkeys,) + + def __call__(self, value): + state = instance_state(value) + m = _state_mapper(state) + key = [m._get_state_attr_by_column( + state, state.dict, + m.mapped_table.columns[k]) + for k in self.colkeys] + if self.composite: + return tuple(key) + else: + return key[0] def column_mapped_collection(mapping_spec): """A dictionary-based collection type with column-based keying. @@ -131,25 +151,27 @@ def column_mapped_collection(mapping_spec): after a session flush. """ + global _state_mapper, instance_state from sqlalchemy.orm.util import _state_mapper from sqlalchemy.orm.attributes import instance_state - cols = [expression._only_column_elements(q, "mapping_spec") - for q in util.to_list(mapping_spec)] - if len(cols) == 1: - def keyfunc(value): - state = instance_state(value) - m = _state_mapper(state) - return m._get_state_attr_by_column(state, state.dict, cols[0]) - else: - mapping_spec = tuple(cols) - def keyfunc(value): - state = instance_state(value) - m = _state_mapper(state) - return tuple(m._get_state_attr_by_column(state, state.dict, c) - for c in mapping_spec) + cols = [c.key for c in [ + expression._only_column_elements(q, "mapping_spec") + for q in util.to_list(mapping_spec)]] + keyfunc = _SerializableColumnGetter(cols) return lambda: MappedCollection(keyfunc) +class _SerializableAttrGetter(object): + def __init__(self, name): + self.name = name + self.getter = operator.attrgetter(name) + + def __call__(self, target): + return self.getter(target) + + def __reduce__(self): + return _SerializableAttrGetter, (self.name, ) + def attribute_mapped_collection(attr_name): """A dictionary-based collection type with attribute-based keying. @@ -163,7 +185,8 @@ def attribute_mapped_collection(attr_name): after a session flush. """ - return lambda: MappedCollection(operator.attrgetter(attr_name)) + getter = _SerializableAttrGetter(attr_name) + return lambda: MappedCollection(getter) def mapped_collection(keyfunc): diff --git a/test/orm/test_pickled.py b/test/orm/test_pickled.py index aa560a2e09..f2d292832e 100644 --- a/test/orm/test_pickled.py +++ b/test/orm/test_pickled.py @@ -11,6 +11,8 @@ from sqlalchemy.orm import mapper, relationship, create_session, \ clear_mappers, exc as orm_exc,\ configure_mappers, Session, lazyload_all,\ lazyload, aliased +from sqlalchemy.orm.collections import attribute_mapped_collection, \ + column_mapped_collection from test.lib import fixtures from test.orm import _fixtures from test.lib.pickleable import User, Address, Dingaling, Order, \ @@ -345,6 +347,72 @@ class PickleTest(fixtures.MappedTest): repickled = loads(dumps(sa_exc)) eq_(repickled.args[0], sa_exc.args[0]) + def test_attribute_mapped_collection(self): + users, addresses = self.tables.users, self.tables.addresses + + mapper(User, users, properties={ + 'addresses':relationship( + Address, + collection_class= + attribute_mapped_collection('email_address') + ) + }) + mapper(Address, addresses) + u1 = User() + u1.addresses = {"email1":Address(email_address="email1")} + for loads, dumps in picklers(): + repickled = loads(dumps(u1)) + eq_(u1.addresses, repickled.addresses) + eq_(repickled.addresses['email1'], + Address(email_address="email1")) + + def test_column_mapped_collection(self): + users, addresses = self.tables.users, self.tables.addresses + + mapper(User, users, properties={ + 'addresses':relationship( + Address, + collection_class= + column_mapped_collection( + addresses.c.email_address) + ) + }) + mapper(Address, addresses) + u1 = User() + u1.addresses = { + "email1":Address(email_address="email1"), + "email2":Address(email_address="email2") + } + for loads, dumps in picklers(): + repickled = loads(dumps(u1)) + eq_(u1.addresses, repickled.addresses) + eq_(repickled.addresses['email1'], + Address(email_address="email1")) + + def test_composite_column_mapped_collection(self): + users, addresses = self.tables.users, self.tables.addresses + + mapper(User, users, properties={ + 'addresses':relationship( + Address, + collection_class= + column_mapped_collection([ + addresses.c.id, + addresses.c.email_address]) + ) + }) + mapper(Address, addresses) + u1 = User() + u1.addresses = { + (1, "email1"):Address(id=1, email_address="email1"), + (2, "email2"):Address(id=2, email_address="email2") + } + for loads, dumps in picklers(): + repickled = loads(dumps(u1)) + eq_(u1.addresses, repickled.addresses) + eq_(repickled.addresses[(1, 'email1')], + Address(id=1, email_address="email1")) + class PolymorphicDeferredTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): -- 2.47.2