from sqlalchemy import schema, util, exc as sa_exc
+
__all__ = ['collection', 'collection_adapter',
'mapped_collection', 'column_mapped_collection',
'attribute_mapped_collection']
__instrumentation_mutex = util.threading.Lock()
+class _SerializableColumnGetter(object):
+ def __init__(self, colkeys):
+ self.colkeys = colkeys
+ self.composite = len(colkeys) > 1
+
+ def __reduce__(self):
+ return _SerializableColumnGetter, (self.colkeys,)
+
+ def __call__(self, value):
+ state = instance_state(value)
+ m = _state_mapper(state)
+ key = [m._get_state_attr_by_column(
+ state, state.dict,
+ m.mapped_table.columns[k])
+ for k in self.colkeys]
+ if self.composite:
+ return tuple(key)
+ else:
+ return key[0]
def column_mapped_collection(mapping_spec):
"""A dictionary-based collection type with column-based keying.
after a session flush.
"""
+ global _state_mapper, instance_state
from sqlalchemy.orm.util import _state_mapper
from sqlalchemy.orm.attributes import instance_state
- cols = [expression._only_column_elements(q, "mapping_spec")
- for q in util.to_list(mapping_spec)]
- if len(cols) == 1:
- def keyfunc(value):
- state = instance_state(value)
- m = _state_mapper(state)
- return m._get_state_attr_by_column(state, state.dict, cols[0])
- else:
- mapping_spec = tuple(cols)
- def keyfunc(value):
- state = instance_state(value)
- m = _state_mapper(state)
- return tuple(m._get_state_attr_by_column(state, state.dict, c)
- for c in mapping_spec)
+ cols = [c.key for c in [
+ expression._only_column_elements(q, "mapping_spec")
+ for q in util.to_list(mapping_spec)]]
+ keyfunc = _SerializableColumnGetter(cols)
return lambda: MappedCollection(keyfunc)
+class _SerializableAttrGetter(object):
+ def __init__(self, name):
+ self.name = name
+ self.getter = operator.attrgetter(name)
+
+ def __call__(self, target):
+ return self.getter(target)
+
+ def __reduce__(self):
+ return _SerializableAttrGetter, (self.name, )
+
def attribute_mapped_collection(attr_name):
"""A dictionary-based collection type with attribute-based keying.
after a session flush.
"""
- return lambda: MappedCollection(operator.attrgetter(attr_name))
+ getter = _SerializableAttrGetter(attr_name)
+ return lambda: MappedCollection(getter)
def mapped_collection(keyfunc):
clear_mappers, exc as orm_exc,\
configure_mappers, Session, lazyload_all,\
lazyload, aliased
+from sqlalchemy.orm.collections import attribute_mapped_collection, \
+ column_mapped_collection
from test.lib import fixtures
from test.orm import _fixtures
from test.lib.pickleable import User, Address, Dingaling, Order, \
repickled = loads(dumps(sa_exc))
eq_(repickled.args[0], sa_exc.args[0])
+ def test_attribute_mapped_collection(self):
+ users, addresses = self.tables.users, self.tables.addresses
+
+ mapper(User, users, properties={
+ 'addresses':relationship(
+ Address,
+ collection_class=
+ attribute_mapped_collection('email_address')
+ )
+ })
+ mapper(Address, addresses)
+ u1 = User()
+ u1.addresses = {"email1":Address(email_address="email1")}
+ for loads, dumps in picklers():
+ repickled = loads(dumps(u1))
+ eq_(u1.addresses, repickled.addresses)
+ eq_(repickled.addresses['email1'],
+ Address(email_address="email1"))
+
+ def test_column_mapped_collection(self):
+ users, addresses = self.tables.users, self.tables.addresses
+
+ mapper(User, users, properties={
+ 'addresses':relationship(
+ Address,
+ collection_class=
+ column_mapped_collection(
+ addresses.c.email_address)
+ )
+ })
+ mapper(Address, addresses)
+ u1 = User()
+ u1.addresses = {
+ "email1":Address(email_address="email1"),
+ "email2":Address(email_address="email2")
+ }
+ for loads, dumps in picklers():
+ repickled = loads(dumps(u1))
+ eq_(u1.addresses, repickled.addresses)
+ eq_(repickled.addresses['email1'],
+ Address(email_address="email1"))
+
+ def test_composite_column_mapped_collection(self):
+ users, addresses = self.tables.users, self.tables.addresses
+
+ mapper(User, users, properties={
+ 'addresses':relationship(
+ Address,
+ collection_class=
+ column_mapped_collection([
+ addresses.c.id,
+ addresses.c.email_address])
+ )
+ })
+ mapper(Address, addresses)
+ u1 = User()
+ u1.addresses = {
+ (1, "email1"):Address(id=1, email_address="email1"),
+ (2, "email2"):Address(id=2, email_address="email2")
+ }
+ for loads, dumps in picklers():
+ repickled = loads(dumps(u1))
+ eq_(u1.addresses, repickled.addresses)
+ eq_(repickled.addresses[(1, 'email1')],
+ Address(id=1, email_address="email1"))
+
class PolymorphicDeferredTest(fixtures.MappedTest):
@classmethod
def define_tables(cls, metadata):