From: Matt Chisholm Date: Sun, 27 Jul 2014 10:15:36 +0000 (+0200) Subject: fix MutableDict.coerce X-Git-Tag: rel_1_0_0b1~205^2~46^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=88f7ec6a0efe68305d5d1ee429565c1778ec6a87;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git fix MutableDict.coerce If a class inherited from MutableDict (say, for instance, to add an update() method), coerce() would give back an instance of MutableDict instead of an instance of the derived class. --- diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index 7469bcbdae..1a4568f237 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -627,10 +627,10 @@ class MutableDict(Mutable, dict): @classmethod def coerce(cls, key, value): - """Convert plain dictionary to MutableDict.""" - if not isinstance(value, MutableDict): + """Convert plain dictionary to instance of this class.""" + if not isinstance(value, cls): if isinstance(value, dict): - return MutableDict(value) + return cls(value) return Mutable.coerce(key, value) else: return value diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py index 32b3e11dd5..e81b91d93a 100644 --- a/test/ext/test_mutable.py +++ b/test/ext/test_mutable.py @@ -299,6 +299,59 @@ class MutableAssociationScalarJSONTest(_MutableDictTestBase, fixtures.MappedTest ) +class CustomMutableAssociationScalarJSONTest(_MutableDictTestBase, fixtures.MappedTest): + + CustomMutableDict = None + + @classmethod + def _type_fixture(cls): + if not(getattr(cls, 'CustomMutableDict')): + MutableDict = super(CustomMutableAssociationScalarJSONTest, cls)._type_fixture() + class CustomMutableDict(MutableDict): + pass + cls.CustomMutableDict = CustomMutableDict + return cls.CustomMutableDict + + @classmethod + def define_tables(cls, metadata): + import json + + class JSONEncodedDict(TypeDecorator): + impl = VARCHAR(50) + + def process_bind_param(self, value, dialect): + if value is not None: + value = json.dumps(value) + + return value + + def process_result_value(self, value, dialect): + if value is not None: + value = json.loads(value) + return value + + CustomMutableDict = cls._type_fixture() + CustomMutableDict.associate_with(JSONEncodedDict) + + Table('foo', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('data', JSONEncodedDict), + Column('unrelated_data', String(50)) + ) + + def test_pickle_parent(self): + # Picklers don't know how to pickle CustomMutableDict, but we aren't testing that here + pass + + def test_coerce(self): + sess = Session() + f1 = Foo(data={'a': 'b'}) + sess.add(f1) + sess.flush() + eq_(type(f1.data), self._type_fixture()) + + class _CompositeTestBase(object): @classmethod def define_tables(cls, metadata):