From 661774055c58d096faf929f34abd947fb5931788 Mon Sep 17 00:00:00 2001
From: Jason Kirtland
Date: Wed, 31 Oct 2007 09:13:12 +0000
Subject: [PATCH] Added util.IdentitySet to support [ticket:676] and [ticket:834]

---
 lib/sqlalchemy/util.py | 205 ++++++++++++++++++++++++++++++++++++++++-
 test/base/utils.py     | 153 +++++++++++++++++++++++++++++-
 2 files changed, 355 insertions(+), 3 deletions(-)

diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index 6c74e115fa..9317d4b9bc 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -4,7 +4,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import sys, warnings, sets
+import itertools, sys, warnings, sets
 import __builtin__
 from sqlalchemy import exceptions, logging
 
@@ -524,6 +524,209 @@ class OrderedSet(Set):
     __isub__ = difference_update
 
 
+class IdentitySet(object):
+    """A set that considers only object id() for uniqueness.
+
+    This strategy has edge cases for builtin types- it's possible to have
+    two 'foo' strings in one of these sets, for example. Use sparingly.
+
+    Members are stored wrapped in _IdentityProxy objects, so any __hash__
+    or __eq__ defined by the members themselves is never consulted.
+    """
+
+    class _IdentityProxy(object):
+        """Proxies an object's id() as its hash and basis for equality."""
+
+        __slots__ = ('obj',)
+
+        def __init__(self, value):
+            self.obj = value
+        def __hash__(self):
+            return id(self.obj)
+        def __eq__(self, other):
+            if isinstance(other, type(self)):
+                return id(self.obj) == id(other.obj)
+            else:
+                return id(self.obj) == id(other)
+        def __ne__(self, other):
+            if isinstance(other, type(self)):
+                return id(self.obj) != id(other.obj)
+            else:
+                return id(self.obj) != id(other)
+
+    def __init__(self, iterable=None):
+        self.set = Set()
+        if iterable:
+            for o in iterable:
+                self.add(o)
+
+    def add(self, value):
+        self.set.add(_id_proxy(value))
+
+    def remove(self, value):
+        """Remove value; raises KeyError(value) if it is not a member."""
+        value = _id_proxy(value)
+        if value not in self:
+            raise KeyError(value.obj)
+        self.set.remove(value)
+
+    def discard(self, value):
+        self.set.discard(_id_proxy(value))
+
+    def pop(self):
+        proxied = self.set.pop()
+        return proxied.obj
+
+    def issubset(self, iterable):
+        if not isinstance(iterable, type(self)):
+            iterable = type(self)(iterable)
+        # compare proxy sets; iterating an IdentitySet yields bare members
+        return self.set.issubset(iterable.set)
+    __le__ = issubset
+
+    def __lt__(self, iterable):
+        if not isinstance(iterable, type(self)):
+            iterable = type(self)(iterable)
+        return len(self) < len(iterable) and self.issubset(iterable)
+
+    def issuperset(self, iterable):
+        if not isinstance(iterable, type(self)):
+            iterable = type(self)(iterable)
+        # compare proxy sets; iterating an IdentitySet yields bare members
+        return self.set.issuperset(iterable.set)
+    __ge__ = issuperset
+
+    def __gt__(self, iterable):
+        if not isinstance(iterable, type(self)):
+            iterable = type(self)(iterable)
+        return len(self) > len(iterable) and self.issuperset(iterable)
+
+    def __eq__(self, other):
+        if isinstance(other, IdentitySet):
+            return self.set == other.set
+        else:
+            return False
+
+    def __ne__(self, other):
+        if isinstance(other, IdentitySet):
+            return self.set != other.set
+        else:
+            return True
+
+    def __cmp__(self, other):
+        raise TypeError('cannot compare sets using cmp()')
+
+    def clear(self):
+        self.set.clear()
+
+    def copy(self):
+        return type(self)(self.set)
+
+    def union(self, iterable):
+        return type(self)(self.set.union(_proxyiter(iterable)))
+
+    def __or__(self, iterable):
+        if not isinstance(iterable, set_types + (IdentitySet,)):
+            return NotImplemented
+        return self.union(iterable)
+    # reflected form applies the same operand type check as __or__
+    __ror__ = __or__
+
+    def update(self, iterable):
+        self.set.update(_proxyiter(iterable))
+
+    def __ior__(self, iterable):
+        if not isinstance(iterable, set_types + (IdentitySet,)):
+            return NotImplemented
+        self.update(iterable)
+        return self
+
+    def difference(self, iterable):
+        return type(self)(self.set.difference(_proxyiter(iterable)))
+
+    def __sub__(self, iterable):
+        if not isinstance(iterable, set_types + (IdentitySet,)):
+            return NotImplemented
+        return self.difference(iterable)
+
+    def __rsub__(self, iterable):
+        if not isinstance(iterable, set_types + (IdentitySet,)):
+            return NotImplemented
+        # difference is not commutative: compute ``iterable - self``
+        return type(self)(iterable).difference(self)
+
+    def difference_update(self, iterable):
+        self.set.difference_update(_proxyiter(iterable))
+
+    def __isub__(self, iterable):
+        if not isinstance(iterable, set_types + (IdentitySet,)):
+            return NotImplemented
+        self.difference_update(iterable)
+        return self
+
+    def intersection(self, iterable):
+        return type(self)(self.set.intersection(_proxyiter(iterable)))
+
+    def __and__(self, iterable):
+        if not isinstance(iterable, set_types + (IdentitySet,)):
+            return NotImplemented
+        return self.intersection(iterable)
+    __rand__ = __and__
+
+    def intersection_update(self, iterable):
+        self.set.intersection_update(_proxyiter(iterable))
+
+    def __iand__(self, iterable):
+        if not isinstance(iterable, set_types + (IdentitySet,)):
+            return NotImplemented
+        self.intersection_update(iterable)
+        return self
+
+    def symmetric_difference(self, iterable):
+        return type(self)(self.set.symmetric_difference(_proxyiter(iterable)))
+
+    def __xor__(self, iterable):
+        if not isinstance(iterable, set_types + (IdentitySet,)):
+            return NotImplemented
+        return self.symmetric_difference(iterable)
+    __rxor__ = __xor__
+
+    def symmetric_difference_update(self, iterable):
+        self.set.symmetric_difference_update(_proxyiter(iterable))
+
+    def __ixor__(self, iterable):
+        if not isinstance(iterable, set_types + (IdentitySet,)):
+            return NotImplemented
+        self.symmetric_difference_update(iterable)
+        return self
+
+    def __iter__(self):
+        for proxy in self.set:
+            assert isinstance(proxy, self._IdentityProxy)
+            yield proxy.obj
+
+    def __len__(self):
+        return len(self.set)
+
+    def __contains__(self, value):
+        return _id_proxy(value) in self.set
+
+    def __hash__(self):
+        raise TypeError('set objects are unhashable')
+
+    def __repr__(self):
+        return '%s(%r)' % (type(self).__name__, list(self))
+
+def _proxyiter(iterable):
+    return itertools.imap(_id_proxy, iterable)
+
+def _id_proxy(item):
+    if isinstance(item, IdentitySet._IdentityProxy):
+        return item
+    else:
+        return IdentitySet._IdentityProxy(item)
+
+
 class UniqueAppender(object):
     """appends items to a collection such that only unique items are added."""
 
diff --git a/test/base/utils.py b/test/base/utils.py
index 28258e9c3d..1cfcd8fb5a 100644
--- a/test/base/utils.py
+++ b/test/base/utils.py
@@ -83,7 +83,156 @@ class ArgSingletonTest(unittest.TestCase):
         m1 = m2 = m3 = None
         MyClass.dispose(MyClass)
         assert len(util.ArgSingleton.instances) == 0
-
-
+
+class ImmutableSubclass(str):
+    pass
+
+class HashOverride(object):
+    def __init__(self, value=None):
+        self.value = value
+    def __hash__(self):
+        return hash(self.value)
+
+class EqOverride(object):
+    def __init__(self, value=None):
+        self.value = value
+    def __eq__(self, other):
+        if isinstance(other, EqOverride):
+            return self.value == other.value
+        else:
+            return False
+    def __ne__(self, other):
+        if isinstance(other, EqOverride):
+            return self.value != other.value
+        else:
+            return True
+
+class HashEqOverride(object):
+    def __init__(self, value=None):
+        self.value = value
+    def __hash__(self):
+        return hash(self.value)
+    def __eq__(self, other):
+        if isinstance(other, HashEqOverride):
+            return self.value == other.value
+        else:
+            return False
+    def __ne__(self, other):
+        if isinstance(other, HashEqOverride):
+            return self.value != other.value
+        else:
+            return True
+
+
+class IdentitySetTest(unittest.TestCase):
+    def assert_eq(self, identityset, expected_iterable):
+        found = sorted(identityset)
+        expected = sorted(expected_iterable)
+        self.assertEquals(found, expected)
+
+    def test_init(self):
+        ids = util.IdentitySet([1,2,3,2,1])
+        self.assert_eq(ids, [1,2,3])
+
+        ids = util.IdentitySet(ids)
+        self.assert_eq(ids, [1,2,3])
+
+        ids = util.IdentitySet()
+        self.assert_eq(ids, [])
+
+        ids = util.IdentitySet([])
+        self.assert_eq(ids, [])
+
+        ids = util.IdentitySet(ids)
+        self.assert_eq(ids, [])
+
+    def test_add(self):
+        for type_ in (object, ImmutableSubclass):
+            data = [type_(), type_()]
+            ids = util.IdentitySet()
+            for i in range(2) + range(2):
+                ids.add(data[i])
+            self.assert_eq(ids, data)
+
+        for type_ in (EqOverride, HashOverride, HashEqOverride):
+            data = [type_(1), type_(1), type_(2)]
+            ids = util.IdentitySet()
+            for i in range(3) + range(3):
+                ids.add(data[i])
+            self.assert_eq(ids, data)
+
+    def test_basic_sanity(self):
+        IdentitySet = util.IdentitySet
+
+        o1, o2, o3 = object(), object(), object()
+        ids = IdentitySet([o1])
+        ids.discard(o1)
+        ids.discard(o1)
+        ids.add(o1)
+        ids.remove(o1)
+        self.assertRaises(KeyError, ids.remove, o1)
+
+        self.assert_(ids.copy() == ids)
+        self.assert_(ids != None)
+        self.assert_(not(ids == None))
+        self.assert_(ids != IdentitySet([o1,o2,o3]))
+        ids.clear()
+        self.assert_(o1 not in ids)
+        ids.add(o2)
+        self.assert_(o2 in ids)
+        self.assert_(ids.pop() == o2)
+        ids.add(o1)
+        self.assert_(len(ids) == 1)
+
+        isuper = IdentitySet([o1,o2])
+        self.assert_(ids < isuper)
+        self.assert_(ids.issubset(isuper))
+        self.assert_(isuper.issuperset(ids))
+        self.assert_(isuper > ids)
+
+        self.assert_(ids.union(isuper) == isuper)
+        self.assert_(ids | isuper == isuper)
+        self.assert_(isuper - ids == IdentitySet([o2]))
+        self.assert_(isuper.difference(ids) == IdentitySet([o2]))
+        self.assert_(ids.intersection(isuper) == IdentitySet([o1]))
+        self.assert_(ids & isuper == IdentitySet([o1]))
+        self.assert_(ids.symmetric_difference(isuper) == IdentitySet([o2]))
+        self.assert_(ids ^ isuper == IdentitySet([o2]))
+
+        ids.update(isuper)
+        ids |= isuper
+        ids.difference_update(isuper)
+        ids -= isuper
+        ids.intersection_update(isuper)
+        ids &= isuper
+        ids.symmetric_difference_update(isuper)
+        ids ^= isuper
+
+        ids.update('foobar')
+        try:
+            ids |= 'foobar'
+            self.assert_(False)
+        except TypeError:
+            self.assert_(True)
+        s = set([o1,o2])
+        s |= ids
+        self.assert_(isinstance(s, IdentitySet))
+
+        # cmp() needs two distinct operands to actually reach __cmp__
+        self.assertRaises(TypeError, cmp, ids, IdentitySet())
+        self.assertRaises(TypeError, hash, ids)
+
+    def test_reflected_difference(self):
+        o1, o2 = object(), object()
+        ids = util.IdentitySet([o1])
+        self.assert_(set([o1, o2]) - ids == util.IdentitySet([o2]))
+
+    def test_subset_custom_hash(self):
+        data = [HashOverride(1), HashOverride(1)]
+        ids = util.IdentitySet(data[:1])
+        self.assert_(ids.issubset(util.IdentitySet(data)))
+        self.assert_(util.IdentitySet(data).issuperset(ids))
+
+
 if __name__ == "__main__":
     testbase.main()
-- 
2.47.2