]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add `sqlalchemy.ext.mutable.MutableList`
authorJeong YunWon <jeong@youknowone.org>
Sat, 13 Feb 2016 09:44:30 +0000 (18:44 +0900)
committerJeong YunWon <jeong@youknowone.org>
Sat, 13 Feb 2016 12:28:46 +0000 (21:28 +0900)
lib/sqlalchemy/ext/mutable.py
test/ext/test_mutable.py

index 97f720cb4c6fb7fda8d61b66092c35d069709469..0081cf7206f343ac71f9fda14ca2f2479486bf24 100644 (file)
@@ -699,3 +699,82 @@ class MutableDict(Mutable, dict):
 
     def __setstate__(self, state):
         self.update(state)
+
+
+class MutableList(Mutable, list):
+    """A list type that implements :class:`.Mutable`.
+
+    The :class:`.MutableList` object implements a list that will
+    emit change events to the underlying mapping when the contents of
+    the list are altered, including when values are added or removed.
+
+    """
+
+    def __setitem__(self, index, value):
+        """Detect list set events and emit change events."""
+        list.__setitem__(self, index, value)
+        self.changed()
+
+    def __setslice__(self, start, end, value):
+        """Detect list set events and emit change events."""
+        list.__setslice__(self, start, end, value)
+        self.changed()
+
+    def __delitem__(self, index):
+        """Detect list del events and emit change events."""
+        list.__delitem__(self, index)
+        self.changed()
+
+    def __delslice__(self, start, end):
+        """Detect list del events and emit change events."""
+        list.__delslice__(self, start, end)
+        self.changed()
+
+    def pop(self, *arg):
+        result = list.pop(self, *arg)
+        self.changed()
+        return result
+
+    def append(self, x):
+        list.append(self, x)
+        self.changed()
+
+    def extend(self, x):
+        list.extend(self, x)
+        self.changed()
+
+    def insert(self, i, x):
+        list.insert(self, i, x)
+        self.changed()
+
+    def remove(self, i):
+        list.remove(self, i)
+        self.changed()
+
+    def clear(self):
+        list.clear(self)
+        self.changed()
+
+    def sort(self):
+        list.sort(self)
+        self.changed()
+
+    def reverse(self):
+        list.reverse(self)
+        self.changed()
+
+    @classmethod
+    def coerce(cls, index, value):
+        """Convert plain list to instance of this class."""
+        if not isinstance(value, cls):
+            if isinstance(value, list):
+                return cls(value)
+            return Mutable.coerce(index, value)
+        else:
+            return value
+
+    def __getstate__(self):
+        return list(self)
+
+    def __setstate__(self, state):
+        self[:] = state
index 602ff911af76959fb077b6b7bbc5f90afbae2dcc..7cdf9f12b13bcc2026bafc668859610649114717 100644 (file)
@@ -8,7 +8,7 @@ from sqlalchemy.testing import eq_, assert_raises_message, assert_raises
 from sqlalchemy.testing.util import picklers
 from sqlalchemy.testing import fixtures
 from sqlalchemy.ext.mutable import MutableComposite
-from sqlalchemy.ext.mutable import MutableDict
+from sqlalchemy.ext.mutable import MutableDict, MutableList
 
 
 class Foo(fixtures.BasicEntity):
@@ -261,6 +261,206 @@ class _MutableDictTestBase(_MutableDictTestFixture):
         eq_(f1.non_mutable_data, {'a': 'b'})
 
 
+class _MutableListTestFixture(object):
+    @classmethod
+    def _type_fixture(cls):
+        return MutableList
+
+    def teardown(self):
+        # clear out mapper events
+        Mapper.dispatch._clear()
+        ClassManager.dispatch._clear()
+        super(_MutableListTestFixture, self).teardown()
+
+
+class _MutableListTestBase(_MutableListTestFixture):
+    run_define_tables = 'each'
+
+    def setup_mappers(cls):
+        foo = cls.tables.foo
+
+        mapper(Foo, foo)
+
+    def test_coerce_none(self):
+        sess = Session()
+        f1 = Foo(data=None)
+        sess.add(f1)
+        sess.commit()
+        eq_(f1.data, None)
+
+    def test_coerce_raise(self):
+        assert_raises_message(
+            ValueError,
+            "Attribute 'data' does not accept objects of type",
+            Foo, data=set([1, 2, 3])
+        )
+
+    def test_in_place_mutation(self):
+        sess = Session()
+
+        f1 = Foo(data=[1, 2])
+        sess.add(f1)
+        sess.commit()
+
+        f1.data[0] = 3
+        sess.commit()
+
+        eq_(f1.data, [3, 2])
+
+    def test_in_place_slice_mutation(self):
+        sess = Session()
+
+        f1 = Foo(data=[1, 2, 3, 4])
+        sess.add(f1)
+        sess.commit()
+
+        f1.data[1:3] = 5, 6
+        sess.commit()
+
+        eq_(f1.data, [1, 5, 6, 4])
+
+    def test_del_slice(self):
+        sess = Session()
+
+        f1 = Foo(data=[1, 2, 3, 4])
+        sess.add(f1)
+        sess.commit()
+
+        del f1.data[1:3]
+        sess.commit()
+
+        eq_(f1.data, [1, 4])
+
+    def test_clear(self):
+        if not hasattr(list, 'clear'):
+            # py2 list doesn't have 'clear'
+            return
+        sess = Session()
+
+        f1 = Foo(data=[1, 2])
+        sess.add(f1)
+        sess.commit()
+
+        f1.data.clear()
+        sess.commit()
+
+        eq_(f1.data, [])
+
+    def test_pop(self):
+        sess = Session()
+
+        f1 = Foo(data=[1, 2, 3])
+        sess.add(f1)
+        sess.commit()
+
+        eq_(f1.data.pop(), 3)
+        eq_(f1.data.pop(0), 1)
+        sess.commit()
+
+        assert_raises(IndexError, f1.data.pop, 5)
+
+        eq_(f1.data, [2])
+
+    def test_append(self):
+        sess = Session()
+
+        f1 = Foo(data=[1, 2])
+        sess.add(f1)
+        sess.commit()
+
+        f1.data.append(5)
+        sess.commit()
+
+        eq_(f1.data, [1, 2, 5])
+
+    def test_extend(self):
+        sess = Session()
+
+        f1 = Foo(data=[1, 2])
+        sess.add(f1)
+        sess.commit()
+
+        f1.data.extend([5])
+        sess.commit()
+
+        eq_(f1.data, [1, 2, 5])
+
+    def test_insert(self):
+        sess = Session()
+
+        f1 = Foo(data=[1, 2])
+        sess.add(f1)
+        sess.commit()
+
+        f1.data.insert(1, 5)
+        sess.commit()
+
+        eq_(f1.data, [1, 5, 2])
+
+    def test_remove(self):
+        sess = Session()
+
+        f1 = Foo(data=[1, 2, 3])
+        sess.add(f1)
+        sess.commit()
+
+        f1.data.remove(2)
+        sess.commit()
+
+        eq_(f1.data, [1, 3])
+
+    def test_sort(self):
+        sess = Session()
+
+        f1 = Foo(data=[1, 3, 2])
+        sess.add(f1)
+        sess.commit()
+
+        f1.data.sort()
+        sess.commit()
+
+        eq_(f1.data, [1, 2, 3])
+
+    def test_reverse(self):
+        sess = Session()
+
+        f1 = Foo(data=[1, 3, 2])
+        sess.add(f1)
+        sess.commit()
+
+        f1.data.reverse()
+        sess.commit()
+
+        eq_(f1.data, [2, 3, 1])
+
+    def test_pickle_parent(self):
+        sess = Session()
+
+        f1 = Foo(data=[1, 2])
+        sess.add(f1)
+        sess.commit()
+        f1.data
+        sess.close()
+
+        for loads, dumps in picklers():
+            sess = Session()
+            f2 = loads(dumps(f1))
+            sess.add(f2)
+            f2.data[0] = 3
+            assert f2 in sess.dirty
+
+    def test_unrelated_flush(self):
+        sess = Session()
+        f1 = Foo(data=[1, 2], unrelated_data="unrelated")
+        sess.add(f1)
+        sess.flush()
+        f1.unrelated_data = "unrelated 2"
+        sess.flush()
+        f1.data[0] = 3
+        sess.commit()
+        eq_(f1.data[0], 3)
+
+
 class MutableColumnDefaultTest(_MutableDictTestFixture, fixtures.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
@@ -349,6 +549,23 @@ class MutableWithScalarJSONTest(_MutableDictTestBase, fixtures.MappedTest):
         self._test_non_mutable()
 
 
+class MutableListWithScalarPickleTest(_MutableListTestBase, fixtures.MappedTest):
+
+    @classmethod
+    def define_tables(cls, metadata):
+        MutableList = cls._type_fixture()
+
+        mutable_pickle = MutableList.as_mutable(PickleType)
+        Table('foo', metadata,
+              Column('id', Integer, primary_key=True,
+                     test_needs_autoincrement=True),
+              Column('skip', mutable_pickle),
+              Column('data', mutable_pickle),
+              Column('non_mutable_data', PickleType),
+              Column('unrelated_data', String(50))
+              )
+
+
 class MutableAssocWithAttrInheritTest(_MutableDictTestBase,
                                       fixtures.MappedTest):