]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Add coverage for dict collections, and fixes for dict support.
authorJason Kirtland <jek@discorporate.us>
Tue, 3 Jul 2007 04:31:56 +0000 (04:31 +0000)
committerJason Kirtland <jek@discorporate.us>
Tue, 3 Jul 2007 04:31:56 +0000 (04:31 +0000)
- Default dict appender name now 'set' to be consistent with duck_typing,
  prefer to remove default appender/remover methods from dict altogether.
- Add coverage for MappedCollection and OrderedDict derived dict collections
- Add coverage for raw object collections
- Fix OrderedDict pop() etc., [ticket:585]
- Update orderinglist unit test and remove 'broken until #213' assertion

lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/util.py
test/ext/orderinglist.py
test/orm/collection.py
test/orm/relationships.py

index b9022ef369d341273c65c85bff07d3f78bdbb1c6..4624c50c1956ed2e291f8e6e2702b5684b4a2407 100644 (file)
@@ -511,8 +511,10 @@ def _instrument_class(cls):
             op = method._sa_instrument_after
             assert op in ('fire_append_event', 'fire_remove_event')
             after = op
-        if before or after:
+        if before:
             methods[name] = before[0], before[1], after
+        elif after:
+            methods[name] = None, None, after
 
     # apply ABC auto-decoration to methods that need it
     for method, decorator in decorators.items():
@@ -775,7 +777,7 @@ def _dict_decorators():
         def __setitem__(self, key, value, _sa_initiator=None):
             if key in self:
                 __del(self, self[key], _sa_initiator)
-            __set(self, value)
+            __set(self, value, _sa_initiator)
             fn(self, key, value)
         _tidy(__setitem__)
         return __setitem__
@@ -817,9 +819,11 @@ def _dict_decorators():
 
     def setdefault(fn):
         def setdefault(self, key, default=None):
-            if key not in self and default is not None:
-                __set(self, default)
-            return fn(self, key, default)
+            if key not in self:
+                self.__setitem__(key, default)
+                return default
+            else:
+                return self.__getitem__(key)
         _tidy(setdefault)
         return setdefault
 
@@ -827,7 +831,8 @@ def _dict_decorators():
         def update(fn):
             def update(self, other):
                 for key in other.keys():
-                    self[key] = other[key]
+                    if not self.has_key(key) or self[key] is not other[key]:
+                        self[key] = other[key]
             _tidy(update)
             return update
     else:
@@ -836,12 +841,15 @@ def _dict_decorators():
                 if __other is not Unspecified:
                     if hasattr(__other, 'keys'):
                         for key in __other.keys():
-                            self[key] = __other[key]
+                            if key not in self or self[key] is not __other[key]:
+                                self[key] = __other[key]
                     else:
                         for key, value in __other:
-                            self[key] = value
+                            if key not in self or self[key] is not value:
+                                self[key] = value
                 for key in kw:
-                    self[key] = kw[key]
+                    if key not in self or self[key] is not kw[key]:
+                        self[key] = kw[key]
             _tidy(update)
             return update
 
@@ -988,7 +996,7 @@ __interfaces = {
                   'iterator': '__iter__',
                   '_decorators': _set_decorators(), },
     # < 0.4 compatible naming (almost), deprecated- use decorators instead.
-    dict: { 'appender': 'append',
+    dict: { 'appender': 'set',
             'remover': 'remove',
             'iterator': 'itervalues',
             '_decorators': _dict_decorators(), },
@@ -1003,7 +1011,7 @@ class MappedCollection(dict):
     """A basic dictionary-based collection class.
 
     Extends dict with the minimal bag semantics that collection classes require.
-    "append" and "remove" are implemented in terms of a keying function: any
+    "set" and "remove" are implemented in terms of a keying function: any
     callable that takes an object and returns an object for use as a dictionary
     key.
     """
@@ -1011,11 +1019,11 @@ class MappedCollection(dict):
     def __init__(self, keyfunc):
         self.keyfunc = keyfunc
 
-    def append(self, value, _sa_initiator=None):
+    def set(self, value, _sa_initiator=None):
         key = self.keyfunc(value)
         self.__setitem__(key, value, _sa_initiator)
-    append = collection.internally_instrumented(append)
-    append = collection.appender(append)
+    set = collection.internally_instrumented(set)
+    set = collection.appender(set)
     
     def remove(self, value, _sa_initiator=None):
         key = self.keyfunc(value)
index 2e1c09c0e5661fbe03ce7754403f0fd730f0d387..3a59cbbbd5531ce2c0a474e3dcc4d43d47406f4f 100644 (file)
@@ -273,22 +273,18 @@ class OrderedDict(dict):
         else:
             self.update(d, **kwargs)
 
-    def keys(self):
-        return list(self._list)
-
     def clear(self):
         self._list = []
         dict.clear(self)
 
-    def update(self, d=None, **kwargs):
-        # d can be a dict or sequence of keys/values
-        if d:
-            if hasattr(d, 'iteritems'):
-                seq = d.iteritems()
+    def update(self, ____sequence=None, **kwargs):
+        if ____sequence is not None:
+            if hasattr(____sequence, 'keys'):
+                for key in ____sequence.keys():
+                    self.__setitem__(key, ____sequence[key])
             else:
-                seq = d
-            for key, value in seq:
-                self.__setitem__(key, value)
+                for key, value in ____sequence:
+                    self[key] = value
         if kwargs:
             self.update(kwargs)
 
@@ -299,33 +295,46 @@ class OrderedDict(dict):
         else:
             return self.__getitem__(key)
 
-    def values(self):
-        return [self[key] for key in self._list]
-
     def __iter__(self):
         return iter(self._list)
 
+    def values(self):
+        return [self[key] for key in self._list]
+
     def itervalues(self):
-        return iter([self[key] for key in self._list])
+        return iter(self.values())
+
+    def keys(self):
+        return list(self._list)
 
     def iterkeys(self):
-        return self.__iter__()
+        return iter(self.keys())
 
-    def iteritems(self):
-        return iter([(key, self[key]) for key in self.keys()])
+    def items(self):
+        return [(key, self[key]) for key in self.keys()]
 
-    def __delitem__(self, key):
-        try:
-            del self._list[self._list.index(key)]
-        except ValueError:
-            raise KeyError(key)
-        dict.__delitem__(self, key)
+    def iteritems(self):
+        return iter(self.items())
 
     def __setitem__(self, key, object):
         if not self.has_key(key):
             self._list.append(key)
         dict.__setitem__(self, key, object)
 
+    def __delitem__(self, key):
+        dict.__delitem__(self, key)
+        self._list.remove(key)
+
+    def pop(self, key):
+        value = dict.pop(self, key)
+        self._list.remove(key)
+        return value
+
+    def popitem(self):
+        item = dict.popitem(self)
+        self._list.remove(item[0])
+        return item
+
 class ThreadLocal(object):
     """An object in which attribute access occurs only within the context of the current thread."""
 
index dc75d066d777ae8b2d2a4eb86ff3c04f877a6208..41348a6482b320fe85554fc82d08572e2b8ab2c3 100644 (file)
@@ -299,43 +299,7 @@ class OrderingListTest(PersistTest):
             self.assert_(srt.bullets[i].position == i)
             self.assert_(srt.bullets[i].text == text)
 
-    def test_replace1(self):
-        self._setup(ordering_list('position'))
-        
-        s1 = Slide('Slide #1')
-        s1.bullets = [ Bullet('1'), Bullet('2'), Bullet('3') ]
-
-        self.assert_(len(s1.bullets) == 3)
-        self.assert_(s1.bullets[2].position == 2)
-
-        session = create_session()
-        session.save(s1)
-        session.flush()
-
-        new_bullet = Bullet('new 2')
-        self.assert_(new_bullet.position is None)
-
-        # naive replacement, no database deletion should occur
-        # with current InstrumentedList __setitem__ semantics
-        s1.bullets[1] = new_bullet
-
-        self.assert_(new_bullet.position == 1)
-        self.assert_(len(s1.bullets) == 3)
-
-        id = s1.id
-
-        session.flush()
-        session.clear()
-
-        srt = session.query(Slide).get(id)
-    
-        self.assert_(srt.bullets)
-        self.assert_(len(srt.bullets) == 4)
-
-        self.assert_(srt.bullets[1].text == '2')
-        self.assert_(srt.bullets[2].text == 'new 2')
-    
-    def test_replace2(self):
+    def test_replace(self):
         self._setup(ordering_list('position'))
 
         s1 = Slide('Slide #1')
@@ -352,7 +316,7 @@ class OrderingListTest(PersistTest):
         self.assert_(new_bullet.position is None)
 
         # mark existing bullet as db-deleted before replacement.
-        session.delete(s1.bullets[1])
+        #session.delete(s1.bullets[1])
         s1.bullets[1] = new_bullet
 
         self.assert_(new_bullet.position == 1)
index be81e7a4d689375c62efc39289c73a634f181e5d..c6d74e6dbfc421f1f10f534f2ba27ca27abb9573 100644 (file)
@@ -1,5 +1,6 @@
 import testbase
 from sqlalchemy import *
+import sqlalchemy.exceptions as exceptions
 from sqlalchemy.orm import create_session, mapper, relation, \
     interfaces, attributes
 import sqlalchemy.orm.collections as collections
@@ -41,6 +42,11 @@ def entity_maker():
     global _id
     _id += 1
     return Entity(_id)
+def dictable_entity(a=None, b=None, c=None):
+    global _id
+    _id += 1
+    return Entity(a or str(_id), b or 'value %s' % _id, c)
+
 
 class CollectionsTest(testbase.PersistTest):
     def _test_adapter(self, typecallable, creator=entity_maker,
@@ -61,7 +67,7 @@ class CollectionsTest(testbase.PersistTest):
         def assert_eq():
             self.assert_(to_set(direct) == canary.data)
             self.assert_(set(adapter) == canary.data)
-        assert_ne = lambda: self.assert_(set(obj.attr) != canary.data)
+        assert_ne = lambda: self.assert_(to_set(direct) != canary.data)
 
         e1, e2 = creator(), creator()
 
@@ -272,7 +278,7 @@ class CollectionsTest(testbase.PersistTest):
     def _test_set(self, typecallable, creator=entity_maker):
         class Foo(object):
             pass
-        
+
         canary = Canary()
         manager.register_attribute(Foo, 'attr', True, extension=canary,
                                    typecallable=typecallable)
@@ -458,15 +464,351 @@ class CollectionsTest(testbase.PersistTest):
         self._test_adapter(SetIsh)
         self._test_set(SetIsh)
         self.assert_(getattr(SetIsh, '_sa_instrumented') == id(SetIsh))
+
+    def _test_dict(self, typecallable, creator=dictable_entity):
+        class Foo(object):
+            pass
+
+        canary = Canary()
+        manager.register_attribute(Foo, 'attr', True, extension=canary,
+                                   typecallable=typecallable)
+
+        obj = Foo()
+        adapter = collections.collection_adapter(obj.attr)
+        direct = obj.attr
+        control = dict()
+
+        def assert_eq():
+            self.assert_(set(direct.values()) == canary.data)
+            self.assert_(set(adapter) == canary.data)
+            self.assert_(direct == control)
+
+        def addall(*values):
+            for item in values:
+                direct.set(item)
+                control[item.a] = item
+            assert_eq()
+        def zap():
+            for item in list(adapter):
+                direct.remove(item)
+            control.clear()
         
+        # assume an 'set' method is available for tests
+        addall(creator())
+
+        if hasattr(direct, '__setitem__'):
+            e = creator()
+            direct[e.a] = e
+            control[e.a] = e
+            assert_eq()
+
+            e = creator(e.a, e.b)
+            direct[e.a] = e
+            control[e.a] = e
+            assert_eq()
+
+        if hasattr(direct, '__delitem__'):
+            e = creator()
+            addall(e)
+
+            del direct[e.a]
+            del control[e.a]
+            assert_eq()
+
+            e = creator()
+            try:
+                del direct[e.a]
+            except KeyError:
+                self.assert_(e not in canary.removed)
+
+        if hasattr(direct, 'clear'):
+            addall(creator(), creator(), creator())
+
+            direct.clear()
+            control.clear()
+            assert_eq()
+            
+            direct.clear()
+            control.clear()
+            assert_eq()
+
+        if hasattr(direct, 'pop'):
+            e = creator()
+            addall(e)
+
+            direct.pop(e.a)
+            control.pop(e.a)
+            assert_eq()
+
+            e = creator()
+            try:
+                direct.pop(e.a)
+            except KeyError:
+                self.assert_(e not in canary.removed)
+
+        if hasattr(direct, 'popitem'):
+            zap()
+            e = creator()
+            addall(e)
+            
+            direct.popitem()
+            control.popitem()
+            assert_eq()
+
+        if hasattr(direct, 'setdefault'):
+            e = creator()
+
+            val_a = direct.setdefault(e.a, e)
+            val_b = control.setdefault(e.a, e)
+            assert_eq()
+            self.assert_(val_a is val_b)
+
+            val_a = direct.setdefault(e.a, e)
+            val_b = control.setdefault(e.a, e)
+            assert_eq()
+            self.assert_(val_a is val_b)
+
+        if hasattr(direct, 'update'):
+            e = creator()
+            d = dict([(ee.a, ee) for ee in [e, creator(), creator()]])
+            addall(e, creator())
+
+            direct.update(d)
+            control.update(d)
+            assert_eq()
+
+            kw = dict([(ee.a, ee) for ee in [e, creator()]])
+            direct.update(**kw)
+            control.update(**kw)
+            assert_eq()
+                        
     def test_dict(self):
-        def dictable_entity(a=None, b=None, c=None):
-            global _id
-            _id += 1
-            return Entity(a or str(_id), b or 'value %s' % _id, c)
+        try:
+            self._test_adapter(dict, dictable_entity,
+                               to_set=lambda c: set(c.values()))
+            self.assert_(False)
+        except exceptions.ArgumentError, e:
+            self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class')
+
+        try:
+            self._test_dict(dict)
+            self.assert_(False)
+        except exceptions.ArgumentError, e:
+            self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class')
+
+    def test_dict_subclass(self):
+        class MyDict(dict):
+            @collection.appender
+            @collection.internally_instrumented
+            def set(self, item, _sa_initiator=None):
+                self.__setitem__(item.a, item, _sa_initiator=_sa_initiator)
+            @collection.remover
+            @collection.internally_instrumented
+            def _remove(self, item, _sa_initiator=None):
+                self.__delitem__(item.a, _sa_initiator=_sa_initiator)
+
+        self._test_adapter(MyDict, dictable_entity,
+                           to_set=lambda c: set(c.values()))
+        self._test_dict(MyDict)
+        self.assert_(getattr(MyDict, '_sa_instrumented') == id(MyDict))
+
+    def test_dict_subclass2(self):
+        class MyEasyDict(collections.MappedCollection):
+            def __init__(self):
+                super(MyEasyDict, self).__init__(lambda e: e.a)
+
+        self._test_adapter(MyEasyDict, dictable_entity,
+                           to_set=lambda c: set(c.values()))
+        self._test_dict(MyEasyDict)
+        self.assert_(getattr(MyEasyDict, '_sa_instrumented') == id(MyEasyDict))
+
+    def test_dict_subclass3(self):
+        class MyOrdered(util.OrderedDict, collections.MappedCollection):
+            def __init__(self):
+                collections.MappedCollection.__init__(self, lambda e: e.a)
+                util.OrderedDict.__init__(self)
+
+        self._test_adapter(MyOrdered, dictable_entity,
+                           to_set=lambda c: set(c.values()))
+        self._test_dict(MyOrdered)
+        self.assert_(getattr(MyOrdered, '_sa_instrumented') == id(MyOrdered))
+        
+
+    def test_dict_duck(self):
+        class DictLike(object):
+            def __init__(self):
+                self.data = dict()
+
+            @collection.appender
+            @collection.replaces(1)
+            def set(self, item):
+                current = self.data.get(item.a, None)
+                self.data[item.a] = item
+                return current
+            @collection.remover
+            def _remove(self, item):
+                del self.data[item.a]
+            def __setitem__(self, key, value):
+                self.data[key] = value
+            def __getitem__(self, key):
+                return self.data[key]
+            def __delitem__(self, key):
+                del self.data[key]
+            def values(self):
+                return self.data.values()
+            def __contains__(self, key):
+                return key in self.data
+            @collection.iterator
+            def itervalues(self):
+                return self.data.itervalues()
+            def __eq__(self, other):
+                return self.data == other
+            def __repr__(self):
+                return 'DictLike(%s)' % repr(self.data)
+
+        self._test_adapter(DictLike, dictable_entity,
+                           to_set=lambda c: set(c.itervalues()))
+        self._test_dict(DictLike)
+        self.assert_(getattr(DictLike, '_sa_instrumented') == id(DictLike))
+
+    def test_dict_emulates(self):
+        class DictIsh(object):
+            __emulates__ = dict
+            def __init__(self):
+                self.data = dict()
+
+            @collection.appender
+            @collection.replaces(1)
+            def set(self, item):
+                current = self.data.get(item.a, None)
+                self.data[item.a] = item
+                return current
+            @collection.remover
+            def _remove(self, item):
+                del self.data[item.a]
+            def __setitem__(self, key, value):
+                self.data[key] = value
+            def __getitem__(self, key):
+                return self.data[key]
+            def __delitem__(self, key):
+                del self.data[key]
+            def values(self):
+                return self.data.values()
+            def __contains__(self, key):
+                return key in self.data
+            @collection.iterator
+            def itervalues(self):
+                return self.data.itervalues()
+            def __eq__(self, other):
+                return self.data == other
+            def __repr__(self):
+                return 'DictIsh(%s)' % repr(self.data)
+
+        self._test_adapter(DictIsh, dictable_entity,
+                           to_set=lambda c: set(c.itervalues()))
+        self._test_dict(DictIsh)
+        self.assert_(getattr(DictIsh, '_sa_instrumented') == id(DictIsh))
+
+    def _test_object(self, typecallable, creator=entity_maker):
+        class Foo(object):
+            pass
+        
+        canary = Canary()
+        manager.register_attribute(Foo, 'attr', True, extension=canary,
+                                   typecallable=typecallable)
+
+        obj = Foo()
+        adapter = collections.collection_adapter(obj.attr)
+        direct = obj.attr
+        control = set()
+
+        def assert_eq():
+            self.assert_(set(direct) == canary.data)
+            self.assert_(set(adapter) == canary.data)
+            self.assert_(direct == control)
+
+        # There is no API for object collections.  We'll make one up
+        # for the purposes of the test.
+        e = creator()
+        direct.push(e)
+        control.add(e)
+        assert_eq()
+
+        direct.zark(e)
+        control.remove(e)
+        assert_eq()
         
-        self._test_adapter(collections.attribute_mapped_collection('a'),
-                           dictable_entity, to_set=lambda c: set(c.values()))
+        e = creator()
+        direct.maybe_zark(e)
+        control.discard(e)
+        assert_eq()
+
+        e = creator()
+        direct.push(e)
+        control.add(e)
+        assert_eq()
+
+        e = creator()
+        direct.maybe_zark(e)
+        control.discard(e)
+        assert_eq()
+
+    def test_object_duck(self):
+        class MyCollection(object):
+            def __init__(self):
+                self.data = set()
+            @collection.appender
+            def push(self, item):
+                self.data.add(item)
+            @collection.remover
+            def zark(self, item):
+                self.data.remove(item)
+            @collection.removes_return()
+            def maybe_zark(self, item):
+                if item in self.data:
+                    self.data.remove(item)
+                    return item
+            @collection.iterator
+            def __iter__(self):
+                return iter(self.data)
+            def __eq__(self, other):
+                return self.data == other
+
+        self._test_adapter(MyCollection)
+        self._test_object(MyCollection)
+        self.assert_(getattr(MyCollection, '_sa_instrumented') ==
+                     id(MyCollection))
+
+    def test_object_emulates(self):
+        class MyCollection2(object):
+            __emulates__ = None
+            def __init__(self):
+                self.data = set()
+            # looks like a list
+            def append(self, item):
+                assert False
+            @collection.appender
+            def push(self, item):
+                self.data.add(item)
+            @collection.remover
+            def zark(self, item):
+                self.data.remove(item)
+            @collection.removes_return()
+            def maybe_zark(self, item):
+                if item in self.data:
+                    self.data.remove(item)
+                    return item
+            @collection.iterator
+            def __iter__(self):
+                return iter(self.data)
+            def __eq__(self, other):
+                return self.data == other
+
+        self._test_adapter(MyCollection2)
+        self._test_object(MyCollection2)
+        self.assert_(getattr(MyCollection2, '_sa_instrumented') ==
+                     id(MyCollection2))
+
 
 class DictHelpersTest(testbase.ORMTest):
     def define_tables(self, metadata):
index 3da10979467437829d474535016db999cb0a24b5..4bb62f3e2945dbd134152b4125de3be99a5e0f72 100644 (file)
@@ -777,7 +777,7 @@ class CustomCollectionsTest(testbase.ORMTest):
         class Bar(object):
             pass
         class AppenderDict(dict):
-            def append(self, item):
+            def set(self, item):
                 self[id(item)] = item
             def remove(self, item):
                 if id(item) in self:
@@ -790,8 +790,8 @@ class CustomCollectionsTest(testbase.ORMTest):
         })
         mapper(Bar, someothertable)
         f = Foo()
-        f.bars.append(Bar())
-        f.bars.append(Bar())
+        f.bars.set(Bar())
+        f.bars.set(Bar())
         sess = create_session()
         sess.save(f)
         sess.flush()