]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fix MutableDict.coerce
authorMatt Chisholm <matt@theory.org>
Sun, 27 Jul 2014 10:15:36 +0000 (12:15 +0200)
committerMatt Chisholm <matt@theory.org>
Sat, 9 Aug 2014 09:03:10 +0000 (11:03 +0200)
If a class inherited from MutableDict (say, for instance, to add an update() method), coerce() would give back an instance of MutableDict instead of an instance of the derived class.

lib/sqlalchemy/ext/mutable.py
test/ext/test_mutable.py

index 7469bcbdae7b11a5ba88651d2b060a17e6f636c3..1a4568f237f320ecc001e2db7d4e07062e09cf68 100644 (file)
@@ -627,10 +627,10 @@ class MutableDict(Mutable, dict):
 
     @classmethod
     def coerce(cls, key, value):
-        """Convert plain dictionary to MutableDict."""
-        if not isinstance(value, MutableDict):
+        """Convert plain dictionary to instance of this class."""
+        if not isinstance(value, cls):
             if isinstance(value, dict):
-                return MutableDict(value)
+                return cls(value)
             return Mutable.coerce(key, value)
         else:
             return value
index 32b3e11dd50447b4e9628b99df7491f3761597f3..e81b91d93a2fc278b9d2f297de42491aa710a22f 100644 (file)
@@ -299,6 +299,59 @@ class MutableAssociationScalarJSONTest(_MutableDictTestBase, fixtures.MappedTest
         )
 
 
+class CustomMutableAssociationScalarJSONTest(_MutableDictTestBase, fixtures.MappedTest):
+
+    CustomMutableDict = None
+
+    @classmethod
+    def _type_fixture(cls):
+        if not(getattr(cls, 'CustomMutableDict')):
+            MutableDict = super(CustomMutableAssociationScalarJSONTest, cls)._type_fixture()
+            class CustomMutableDict(MutableDict):
+                pass
+            cls.CustomMutableDict = CustomMutableDict
+        return cls.CustomMutableDict
+
+    @classmethod
+    def define_tables(cls, metadata):
+        import json
+
+        class JSONEncodedDict(TypeDecorator):
+            impl = VARCHAR(50)
+
+            def process_bind_param(self, value, dialect):
+                if value is not None:
+                    value = json.dumps(value)
+
+                return value
+
+            def process_result_value(self, value, dialect):
+                if value is not None:
+                    value = json.loads(value)
+                return value
+
+        CustomMutableDict = cls._type_fixture()
+        CustomMutableDict.associate_with(JSONEncodedDict)
+
+        Table('foo', metadata,
+            Column('id', Integer, primary_key=True,
+                            test_needs_autoincrement=True),
+            Column('data', JSONEncodedDict),
+            Column('unrelated_data', String(50))
+        )
+
+    def test_pickle_parent(self):
+        # Picklers don't know how to pickle CustomMutableDict, but we aren't testing that here
+        pass
+
+    def test_coerce(self):
+        sess = Session()
+        f1 = Foo(data={'a': 'b'})
+        sess.add(f1)
+        sess.flush()
+        eq_(type(f1.data), self._type_fixture())
+
+
 class _CompositeTestBase(object):
     @classmethod
     def define_tables(cls, metadata):