From: Jason Kirtland Date: Tue, 3 Jul 2007 04:31:56 +0000 (+0000) Subject: - Add coverage for dict collections, and fixes for dict support. X-Git-Tag: rel_0_4_6~133 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e898d3f13072747d7557d0939ec805e974fa3576;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Add coverage for dict collections, and fixes for dict support. - Default dict appender name now 'set' to be consistent with duck_typing, prefer to remove default appender/remover methods from dict altogether. - Add coverage for MappedCollection and OrderedDict derived dict collections - Add coverage for raw object collections - Fix OrderedDict pop() etc., [ticket:585] - Update orderinglist unit test and remove 'broken until #213' assertion --- diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index b9022ef369..4624c50c19 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -511,8 +511,10 @@ def _instrument_class(cls): op = method._sa_instrument_after assert op in ('fire_append_event', 'fire_remove_event') after = op - if before or after: + if before: methods[name] = before[0], before[1], after + elif after: + methods[name] = None, None, after # apply ABC auto-decoration to methods that need it for method, decorator in decorators.items(): @@ -775,7 +777,7 @@ def _dict_decorators(): def __setitem__(self, key, value, _sa_initiator=None): if key in self: __del(self, self[key], _sa_initiator) - __set(self, value) + __set(self, value, _sa_initiator) fn(self, key, value) _tidy(__setitem__) return __setitem__ @@ -817,9 +819,11 @@ def _dict_decorators(): def setdefault(fn): def setdefault(self, key, default=None): - if key not in self and default is not None: - __set(self, default) - return fn(self, key, default) + if key not in self: + self.__setitem__(key, default) + return default + else: + return self.__getitem__(key) _tidy(setdefault) return setdefault @@ -827,7 +831,8 @@ def _dict_decorators(): def update(fn): def update(self, other): for key in other.keys(): - self[key] = other[key] + if not self.has_key(key) or self[key] is not other[key]: + self[key] = other[key] _tidy(update) return update else: @@ -836,12 +841,15 @@ def _dict_decorators(): if __other is not Unspecified: if hasattr(__other, 'keys'): for key in __other.keys(): - self[key] = __other[key] + if key not in self or self[key] is not __other[key]: + self[key] = __other[key] else: for key, value in __other: - self[key] = value + if key not in self or self[key] is not value: + self[key] = value for key in kw: - self[key] = kw[key] + if key not in self or self[key] is not kw[key]: + self[key] = kw[key] _tidy(update) return update @@ -988,7 +996,7 @@ __interfaces = { 'iterator': '__iter__', '_decorators': _set_decorators(), }, # < 0.4 compatible naming (almost), deprecated- use decorators instead. - dict: { 'appender': 'append', + dict: { 'appender': 'set', 'remover': 'remove', 'iterator': 'itervalues', '_decorators': _dict_decorators(), }, @@ -1003,7 +1011,7 @@ class MappedCollection(dict): """A basic dictionary-based collection class. Extends dict with the minimal bag semantics that collection classes require. - "append" and "remove" are implemented in terms of a keying function: any + "set" and "remove" are implemented in terms of a keying function: any callable that takes an object and returns an object for use as a dictionary key. """ @@ -1011,11 +1019,11 @@ class MappedCollection(dict): def __init__(self, keyfunc): self.keyfunc = keyfunc - def append(self, value, _sa_initiator=None): + def set(self, value, _sa_initiator=None): key = self.keyfunc(value) self.__setitem__(key, value, _sa_initiator) - append = collection.internally_instrumented(append) - append = collection.appender(append) + set = collection.internally_instrumented(set) + set = collection.appender(set) def remove(self, value, _sa_initiator=None): key = self.keyfunc(value) diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 2e1c09c0e5..3a59cbbbd5 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -273,22 +273,18 @@ class OrderedDict(dict): else: self.update(d, **kwargs) - def keys(self): - return list(self._list) - def clear(self): self._list = [] dict.clear(self) - def update(self, d=None, **kwargs): - # d can be a dict or sequence of keys/values - if d: - if hasattr(d, 'iteritems'): - seq = d.iteritems() + def update(self, ____sequence=None, **kwargs): + if ____sequence is not None: + if hasattr(____sequence, 'keys'): + for key in ____sequence.keys(): + self.__setitem__(key, ____sequence[key]) else: - seq = d - for key, value in seq: - self.__setitem__(key, value) + for key, value in ____sequence: + self[key] = value if kwargs: self.update(kwargs) @@ -299,33 +295,46 @@ class OrderedDict(dict): else: return self.__getitem__(key) - def values(self): - return [self[key] for key in self._list] - def __iter__(self): return iter(self._list) + def values(self): + return [self[key] for key in self._list] + def itervalues(self): - return iter([self[key] for key in self._list]) + return iter(self.values()) + + def keys(self): + return list(self._list) def iterkeys(self): - return self.__iter__() + return iter(self.keys()) - def iteritems(self): - return iter([(key, self[key]) for key in self.keys()]) + def items(self): + return [(key, self[key]) for key in self.keys()] - def __delitem__(self, key): - try: - del self._list[self._list.index(key)] - except ValueError: - raise KeyError(key) - dict.__delitem__(self, key) + def iteritems(self): + return iter(self.items()) def __setitem__(self, key, object): if not self.has_key(key): self._list.append(key) dict.__setitem__(self, key, object) + def __delitem__(self, key): + dict.__delitem__(self, key) + self._list.remove(key) + + def pop(self, key): + value = dict.pop(self, key) + self._list.remove(key) + return value + + def popitem(self): + item = dict.popitem(self) + self._list.remove(item[0]) + return item + class ThreadLocal(object): """An object in which attribute access occurs only within the context of the current thread.""" diff --git a/test/ext/orderinglist.py b/test/ext/orderinglist.py index dc75d066d7..41348a6482 100644 --- a/test/ext/orderinglist.py +++ b/test/ext/orderinglist.py @@ -299,43 +299,7 @@ class OrderingListTest(PersistTest): self.assert_(srt.bullets[i].position == i) self.assert_(srt.bullets[i].text == text) - def test_replace1(self): - self._setup(ordering_list('position')) - - s1 = Slide('Slide #1') - s1.bullets = [ Bullet('1'), Bullet('2'), Bullet('3') ] - - self.assert_(len(s1.bullets) == 3) - self.assert_(s1.bullets[2].position == 2) - - session = create_session() - session.save(s1) - session.flush() - - new_bullet = Bullet('new 2') - self.assert_(new_bullet.position is None) - - # naive replacement, no database deletion should occur - # with current InstrumentedList __setitem__ semantics - s1.bullets[1] = new_bullet - - self.assert_(new_bullet.position == 1) - self.assert_(len(s1.bullets) == 3) - - id = s1.id - - session.flush() - session.clear() - - srt = session.query(Slide).get(id) - - self.assert_(srt.bullets) - self.assert_(len(srt.bullets) == 4) - - self.assert_(srt.bullets[1].text == '2') - self.assert_(srt.bullets[2].text == 'new 2') - - def test_replace2(self): + def test_replace(self): self._setup(ordering_list('position')) s1 = Slide('Slide #1') @@ -352,7 +316,7 @@ class OrderingListTest(PersistTest): self.assert_(new_bullet.position is None) # mark existing bullet as db-deleted before replacement. - session.delete(s1.bullets[1]) + #session.delete(s1.bullets[1]) s1.bullets[1] = new_bullet self.assert_(new_bullet.position == 1) diff --git a/test/orm/collection.py b/test/orm/collection.py index be81e7a4d6..c6d74e6dbf 100644 --- a/test/orm/collection.py +++ b/test/orm/collection.py @@ -1,5 +1,6 @@ import testbase from sqlalchemy import * +import sqlalchemy.exceptions as exceptions from sqlalchemy.orm import create_session, mapper, relation, \ interfaces, attributes import sqlalchemy.orm.collections as collections @@ -41,6 +42,11 @@ def entity_maker(): global _id _id += 1 return Entity(_id) +def dictable_entity(a=None, b=None, c=None): + global _id + _id += 1 + return Entity(a or str(_id), b or 'value %s' % _id, c) + class CollectionsTest(testbase.PersistTest): def _test_adapter(self, typecallable, creator=entity_maker, @@ -61,7 +67,7 @@ class CollectionsTest(testbase.PersistTest): def assert_eq(): self.assert_(to_set(direct) == canary.data) self.assert_(set(adapter) == canary.data) - assert_ne = lambda: self.assert_(set(obj.attr) != canary.data) + assert_ne = lambda: self.assert_(to_set(direct) != canary.data) e1, e2 = creator(), creator() @@ -272,7 +278,7 @@ class CollectionsTest(testbase.PersistTest): def _test_set(self, typecallable, creator=entity_maker): class Foo(object): pass - + canary = Canary() manager.register_attribute(Foo, 'attr', True, extension=canary, typecallable=typecallable) @@ -458,15 +464,351 @@ class CollectionsTest(testbase.PersistTest): self._test_adapter(SetIsh) self._test_set(SetIsh) self.assert_(getattr(SetIsh, '_sa_instrumented') == id(SetIsh)) + + def _test_dict(self, typecallable, creator=dictable_entity): + class Foo(object): + pass + + canary = Canary() + manager.register_attribute(Foo, 'attr', True, extension=canary, + typecallable=typecallable) + + obj = Foo() + adapter = collections.collection_adapter(obj.attr) + direct = obj.attr + control = dict() + + def assert_eq(): + self.assert_(set(direct.values()) == canary.data) + self.assert_(set(adapter) == canary.data) + self.assert_(direct == control) + + def addall(*values): + for item in values: + direct.set(item) + control[item.a] = item + assert_eq() + def zap(): + for item in list(adapter): + direct.remove(item) + control.clear() + # assume an 'set' method is available for tests + addall(creator()) + + if hasattr(direct, '__setitem__'): + e = creator() + direct[e.a] = e + control[e.a] = e + assert_eq() + + e = creator(e.a, e.b) + direct[e.a] = e + control[e.a] = e + assert_eq() + + if hasattr(direct, '__delitem__'): + e = creator() + addall(e) + + del direct[e.a] + del control[e.a] + assert_eq() + + e = creator() + try: + del direct[e.a] + except KeyError: + self.assert_(e not in canary.removed) + + if hasattr(direct, 'clear'): + addall(creator(), creator(), creator()) + + direct.clear() + control.clear() + assert_eq() + + direct.clear() + control.clear() + assert_eq() + + if hasattr(direct, 'pop'): + e = creator() + addall(e) + + direct.pop(e.a) + control.pop(e.a) + assert_eq() + + e = creator() + try: + direct.pop(e.a) + except KeyError: + self.assert_(e not in canary.removed) + + if hasattr(direct, 'popitem'): + zap() + e = creator() + addall(e) + + direct.popitem() + control.popitem() + assert_eq() + + if hasattr(direct, 'setdefault'): + e = creator() + + val_a = direct.setdefault(e.a, e) + val_b = control.setdefault(e.a, e) + assert_eq() + self.assert_(val_a is val_b) + + val_a = direct.setdefault(e.a, e) + val_b = control.setdefault(e.a, e) + assert_eq() + self.assert_(val_a is val_b) + + if hasattr(direct, 'update'): + e = creator() + d = dict([(ee.a, ee) for ee in [e, creator(), creator()]]) + addall(e, creator()) + + direct.update(d) + control.update(d) + assert_eq() + + kw = dict([(ee.a, ee) for ee in [e, creator()]]) + direct.update(**kw) + control.update(**kw) + assert_eq() + def test_dict(self): - def dictable_entity(a=None, b=None, c=None): - global _id - _id += 1 - return Entity(a or str(_id), b or 'value %s' % _id, c) + try: + self._test_adapter(dict, dictable_entity, + to_set=lambda c: set(c.values())) + self.assert_(False) + except exceptions.ArgumentError, e: + self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class') + + try: + self._test_dict(dict) + self.assert_(False) + except exceptions.ArgumentError, e: + self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class') + + def test_dict_subclass(self): + class MyDict(dict): + @collection.appender + @collection.internally_instrumented + def set(self, item, _sa_initiator=None): + self.__setitem__(item.a, item, _sa_initiator=_sa_initiator) + @collection.remover + @collection.internally_instrumented + def _remove(self, item, _sa_initiator=None): + self.__delitem__(item.a, _sa_initiator=_sa_initiator) + + self._test_adapter(MyDict, dictable_entity, + to_set=lambda c: set(c.values())) + self._test_dict(MyDict) + self.assert_(getattr(MyDict, '_sa_instrumented') == id(MyDict)) + + def test_dict_subclass2(self): + class MyEasyDict(collections.MappedCollection): + def __init__(self): + super(MyEasyDict, self).__init__(lambda e: e.a) + + self._test_adapter(MyEasyDict, dictable_entity, + to_set=lambda c: set(c.values())) + self._test_dict(MyEasyDict) + self.assert_(getattr(MyEasyDict, '_sa_instrumented') == id(MyEasyDict)) + + def test_dict_subclass3(self): + class MyOrdered(util.OrderedDict, collections.MappedCollection): + def __init__(self): + collections.MappedCollection.__init__(self, lambda e: e.a) + util.OrderedDict.__init__(self) + + self._test_adapter(MyOrdered, dictable_entity, + to_set=lambda c: set(c.values())) + self._test_dict(MyOrdered) + self.assert_(getattr(MyOrdered, '_sa_instrumented') == id(MyOrdered)) + + + def test_dict_duck(self): + class DictLike(object): + def __init__(self): + self.data = dict() + + @collection.appender + @collection.replaces(1) + def set(self, item): + current = self.data.get(item.a, None) + self.data[item.a] = item + return current + @collection.remover + def _remove(self, item): + del self.data[item.a] + def __setitem__(self, key, value): + self.data[key] = value + def __getitem__(self, key): + return self.data[key] + def __delitem__(self, key): + del self.data[key] + def values(self): + return self.data.values() + def __contains__(self, key): + return key in self.data + @collection.iterator + def itervalues(self): + return self.data.itervalues() + def __eq__(self, other): + return self.data == other + def __repr__(self): + return 'DictLike(%s)' % repr(self.data) + + self._test_adapter(DictLike, dictable_entity, + to_set=lambda c: set(c.itervalues())) + self._test_dict(DictLike) + self.assert_(getattr(DictLike, '_sa_instrumented') == id(DictLike)) + + def test_dict_emulates(self): + class DictIsh(object): + __emulates__ = dict + def __init__(self): + self.data = dict() + + @collection.appender + @collection.replaces(1) + def set(self, item): + current = self.data.get(item.a, None) + self.data[item.a] = item + return current + @collection.remover + def _remove(self, item): + del self.data[item.a] + def __setitem__(self, key, value): + self.data[key] = value + def __getitem__(self, key): + return self.data[key] + def __delitem__(self, key): + del self.data[key] + def values(self): + return self.data.values() + def __contains__(self, key): + return key in self.data + @collection.iterator + def itervalues(self): + return self.data.itervalues() + def __eq__(self, other): + return self.data == other + def __repr__(self): + return 'DictIsh(%s)' % repr(self.data) + + self._test_adapter(DictIsh, dictable_entity, + to_set=lambda c: set(c.itervalues())) + self._test_dict(DictIsh) + self.assert_(getattr(DictIsh, '_sa_instrumented') == id(DictIsh)) + + def _test_object(self, typecallable, creator=entity_maker): + class Foo(object): + pass + + canary = Canary() + manager.register_attribute(Foo, 'attr', True, extension=canary, + typecallable=typecallable) + + obj = Foo() + adapter = collections.collection_adapter(obj.attr) + direct = obj.attr + control = set() + + def assert_eq(): + self.assert_(set(direct) == canary.data) + self.assert_(set(adapter) == canary.data) + self.assert_(direct == control) + + # There is no API for object collections. We'll make one up + # for the purposes of the test. + e = creator() + direct.push(e) + control.add(e) + assert_eq() + + direct.zark(e) + control.remove(e) + assert_eq() - self._test_adapter(collections.attribute_mapped_collection('a'), - dictable_entity, to_set=lambda c: set(c.values())) + e = creator() + direct.maybe_zark(e) + control.discard(e) + assert_eq() + + e = creator() + direct.push(e) + control.add(e) + assert_eq() + + e = creator() + direct.maybe_zark(e) + control.discard(e) + assert_eq() + + def test_object_duck(self): + class MyCollection(object): + def __init__(self): + self.data = set() + @collection.appender + def push(self, item): + self.data.add(item) + @collection.remover + def zark(self, item): + self.data.remove(item) + @collection.removes_return() + def maybe_zark(self, item): + if item in self.data: + self.data.remove(item) + return item + @collection.iterator + def __iter__(self): + return iter(self.data) + def __eq__(self, other): + return self.data == other + + self._test_adapter(MyCollection) + self._test_object(MyCollection) + self.assert_(getattr(MyCollection, '_sa_instrumented') == + id(MyCollection)) + + def test_object_emulates(self): + class MyCollection2(object): + __emulates__ = None + def __init__(self): + self.data = set() + # looks like a list + def append(self, item): + assert False + @collection.appender + def push(self, item): + self.data.add(item) + @collection.remover + def zark(self, item): + self.data.remove(item) + @collection.removes_return() + def maybe_zark(self, item): + if item in self.data: + self.data.remove(item) + return item + @collection.iterator + def __iter__(self): + return iter(self.data) + def __eq__(self, other): + return self.data == other + + self._test_adapter(MyCollection2) + self._test_object(MyCollection2) + self.assert_(getattr(MyCollection2, '_sa_instrumented') == + id(MyCollection2)) + class DictHelpersTest(testbase.ORMTest): def define_tables(self, metadata): diff --git a/test/orm/relationships.py b/test/orm/relationships.py index 3da1097946..4bb62f3e29 100644 --- a/test/orm/relationships.py +++ b/test/orm/relationships.py @@ -777,7 +777,7 @@ class CustomCollectionsTest(testbase.ORMTest): class Bar(object): pass class AppenderDict(dict): - def append(self, item): + def set(self, item): self[id(item)] = item def remove(self, item): if id(item) in self: @@ -790,8 +790,8 @@ class CustomCollectionsTest(testbase.ORMTest): }) mapper(Bar, someothertable) f = Foo() - f.bars.append(Bar()) - f.bars.append(Bar()) + f.bars.set(Bar()) + f.bars.set(Bar()) sess = create_session() sess.save(f) sess.flush()