]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add `sqlalchemy.ext.mutable.MutableSet`
authorJeong YunWon <jeong@youknowone.org>
Sat, 13 Feb 2016 10:20:12 +0000 (19:20 +0900)
committerJeong YunWon <jeong@youknowone.org>
Sat, 13 Feb 2016 12:28:50 +0000 (21:28 +0900)
from https://bitbucket.org/zzzeek/sqlalchemy/issues/3297

lib/sqlalchemy/ext/mutable.py
test/ext/test_mutable.py

index 0081cf7206f343ac71f9fda14ca2f2479486bf24..aa5be57ff11433363e1d9e9f5ec8b1ae0347d4bb 100644 (file)
@@ -778,3 +778,68 @@ class MutableList(Mutable, list):
 
     def __setstate__(self, state):
         self[:] = state
+
+
+class MutableSet(Mutable, set):
+    """A set type that implements :class:`.Mutable`.
+
+    The :class:`.MutableSet` object implements a list that will
+    emit change events to the underlying mapping when the contents of
+    the set are altered, including when values are added or removed.
+    """
+
+    def update(self, *arg):
+        set.update(self, *arg)
+        self.changed()
+
+    def intersection_update(self, *arg):
+        set.intersection_update(self, *arg)
+        self.changed()
+
+    def difference_update(self, *arg):
+        set.difference_update(self, *arg)
+        self.changed()
+
+    def symmetric_difference_update(self, *arg):
+        set.symmetric_difference_update(self, *arg)
+        self.changed()
+
+    def add(self, elem):
+        set.add(self, elem)
+        self.changed()
+
+    def remove(self, elem):
+        set.remove(self, elem)
+        self.changed()
+
+    def discard(self, elem):
+        set.discard(self, elem)
+        self.changed()
+
+    def pop(self, *arg):
+        result = set.pop(self, *arg)
+        self.changed()
+        return result
+
+    def clear(self):
+        set.clear(self)
+        self.changed()
+
+    @classmethod
+    def coerce(cls, index, value):
+        """Convert plain set to instance of this class."""
+        if not isinstance(value, cls):
+            if isinstance(value, set):
+                return cls(value)
+            return Mutable.coerce(index, value)
+        else:
+            return value
+
+    def __getstate__(self):
+        return set(self)
+
+    def __setstate__(self, state):
+        self.update(state)
+
+    def __reduce_ex__(self, proto):
+        return (self.__class__, (list(self), ))
index 7cdf9f12b13bcc2026bafc668859610649114717..1e1a75e7e3d377b685f4919cbaff04ce08ab9f4e 100644 (file)
@@ -8,7 +8,7 @@ from sqlalchemy.testing import eq_, assert_raises_message, assert_raises
 from sqlalchemy.testing.util import picklers
 from sqlalchemy.testing import fixtures
 from sqlalchemy.ext.mutable import MutableComposite
-from sqlalchemy.ext.mutable import MutableDict, MutableList
+from sqlalchemy.ext.mutable import MutableDict, MutableList, MutableSet
 
 
 class Foo(fixtures.BasicEntity):
@@ -461,6 +461,183 @@ class _MutableListTestBase(_MutableListTestFixture):
         eq_(f1.data[0], 3)
 
 
+class _MutableSetTestFixture(object):
+    @classmethod
+    def _type_fixture(cls):
+        return MutableSet
+
+    def teardown(self):
+        # clear out mapper events
+        Mapper.dispatch._clear()
+        ClassManager.dispatch._clear()
+        super(_MutableSetTestFixture, self).teardown()
+
+
+class _MutableSetTestBase(_MutableSetTestFixture):
+    run_define_tables = 'each'
+
+    def setup_mappers(cls):
+        foo = cls.tables.foo
+
+        mapper(Foo, foo)
+
+    def test_coerce_none(self):
+        sess = Session()
+        f1 = Foo(data=None)
+        sess.add(f1)
+        sess.commit()
+        eq_(f1.data, None)
+
+    def test_coerce_raise(self):
+        assert_raises_message(
+            ValueError,
+            "Attribute 'data' does not accept objects of type",
+            Foo, data=[1, 2, 3]
+        )
+
+    def test_clear(self):
+        sess = Session()
+
+        f1 = Foo(data=set([1, 2]))
+        sess.add(f1)
+        sess.commit()
+
+        f1.data.clear()
+        sess.commit()
+
+        eq_(f1.data, set())
+
+    def test_pop(self):
+        sess = Session()
+
+        f1 = Foo(data=set([1]))
+        sess.add(f1)
+        sess.commit()
+
+        eq_(f1.data.pop(), 1)
+        sess.commit()
+
+        assert_raises(KeyError, f1.data.pop)
+
+        eq_(f1.data, set())
+
+    def test_add(self):
+        sess = Session()
+
+        f1 = Foo(data=set([1, 2]))
+        sess.add(f1)
+        sess.commit()
+
+        f1.data.add(5)
+        sess.commit()
+
+        eq_(f1.data, set([1, 2, 5]))
+
+    def test_update(self):
+        sess = Session()
+
+        f1 = Foo(data=set([1, 2]))
+        sess.add(f1)
+        sess.commit()
+
+        f1.data.update(set([2, 5]))
+        sess.commit()
+
+        eq_(f1.data, set([1, 2, 5]))
+
+    def test_intersection_update(self):
+        sess = Session()
+
+        f1 = Foo(data=set([1, 2]))
+        sess.add(f1)
+        sess.commit()
+
+        f1.data.intersection_update(set([2, 5]))
+        sess.commit()
+
+        eq_(f1.data, set([2]))
+
+    def test_difference_update(self):
+        sess = Session()
+
+        f1 = Foo(data=set([1, 2]))
+        sess.add(f1)
+        sess.commit()
+
+        f1.data.difference_update(set([2, 5]))
+        sess.commit()
+
+        eq_(f1.data, set([1]))
+
+    def test_symmetric_difference_update(self):
+        sess = Session()
+
+        f1 = Foo(data=set([1, 2]))
+        sess.add(f1)
+        sess.commit()
+
+        f1.data.symmetric_difference_update(set([2, 5]))
+        sess.commit()
+
+        eq_(f1.data, set([1, 5]))
+
+    def test_remove(self):
+        sess = Session()
+
+        f1 = Foo(data=set([1, 2, 3]))
+        sess.add(f1)
+        sess.commit()
+
+        f1.data.remove(2)
+        sess.commit()
+
+        eq_(f1.data, set([1, 3]))
+
+    def test_discard(self):
+        sess = Session()
+
+        f1 = Foo(data=set([1, 2, 3]))
+        sess.add(f1)
+        sess.commit()
+
+        f1.data.discard(2)
+        sess.commit()
+
+        eq_(f1.data, set([1, 3]))
+
+        f1.data.discard(2)
+        sess.commit()
+
+        eq_(f1.data, set([1, 3]))
+
+    def test_pickle_parent(self):
+        sess = Session()
+
+        f1 = Foo(data=set([1, 2]))
+        sess.add(f1)
+        sess.commit()
+        f1.data
+        sess.close()
+
+        for loads, dumps in picklers():
+            sess = Session()
+            f2 = loads(dumps(f1))
+            sess.add(f2)
+            f2.data.add(3)
+            assert f2 in sess.dirty
+
+    def test_unrelated_flush(self):
+        sess = Session()
+        f1 = Foo(data=set([1, 2]), unrelated_data="unrelated")
+        sess.add(f1)
+        sess.flush()
+        f1.unrelated_data = "unrelated 2"
+        sess.flush()
+        f1.data.add(3)
+        sess.commit()
+        eq_(f1.data, set([1, 2, 3]))
+
+
 class MutableColumnDefaultTest(_MutableDictTestFixture, fixtures.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
@@ -566,6 +743,23 @@ class MutableListWithScalarPickleTest(_MutableListTestBase, fixtures.MappedTest)
               )
 
 
+class MutableSetWithScalarPickleTest(_MutableSetTestBase, fixtures.MappedTest):
+
+    @classmethod
+    def define_tables(cls, metadata):
+        MutableSet = cls._type_fixture()
+
+        mutable_pickle = MutableSet.as_mutable(PickleType)
+        Table('foo', metadata,
+              Column('id', Integer, primary_key=True,
+                     test_needs_autoincrement=True),
+              Column('skip', mutable_pickle),
+              Column('data', mutable_pickle),
+              Column('non_mutable_data', PickleType),
+              Column('unrelated_data', String(50))
+              )
+
+
 class MutableAssocWithAttrInheritTest(_MutableDictTestBase,
                                       fixtures.MappedTest):