git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Added util.IdentitySet to support [ticket:676] and [ticket:834]
author Jason Kirtland <jek@discorporate.us>
Wed, 31 Oct 2007 09:13:12 +0000 (09:13 +0000)
committer Jason Kirtland <jek@discorporate.us>
Wed, 31 Oct 2007 09:13:12 +0000 (09:13 +0000)
lib/sqlalchemy/util.py
test/base/utils.py

index 6c74e115fabd54242bdc22c5caa41da7447af6f8..9317d4b9bc701bab557b2a26d94766d64503ca8a 100644 (file)
@@ -4,7 +4,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import sys, warnings, sets
+import itertools, sys, warnings, sets
 import __builtin__
 
 from sqlalchemy import exceptions, logging
@@ -524,6 +524,197 @@ class OrderedSet(Set):
     __isub__ = difference_update
 
 
+class IdentitySet(object):
+    """A set that considers only object id() for uniqueness.
+
+    This strategy has edge cases for builtin types- it's possible to have
+    two 'foo' strings in one of these sets, for example.  Use sparingly.
+
+    Members are kept in ``self.set`` wrapped in ``_IdentityProxy`` objects;
+    iteration and ``pop()`` hand back the unwrapped originals.
+    """
+
+    class _IdentityProxy(object):
+        """Proxies an object's id() as its hash and basis for equality."""
+
+        __slots__ = ('obj',)
+
+        def __init__(self, value):
+            self.obj = value
+        def __hash__(self):
+            return id(self.obj)
+        def __eq__(self, other):
+            # Equal to another proxy of the same object, or to the wrapped
+            # object itself.
+            if isinstance(other, type(self)):
+                return id(self.obj) == id(other.obj)
+            else:
+                return id(self.obj) == id(other)
+        def __ne__(self, other):
+            if isinstance(other, type(self)):
+                return id(self.obj) != id(other.obj)
+            else:
+                return id(self.obj) != id(other)
+
+    def __init__(self, iterable=None):
+        """Create the set, adding every item of ``iterable`` if given."""
+        self.set = Set()
+        if iterable:
+            for o in iterable:
+                self.add(o)
+
+    def add(self, value):
+        """Add ``value``, keyed on its id()."""
+        self.set.add(_id_proxy(value))
+
+    def remove(self, value):
+        """Remove ``value``; raise KeyError(value) if it is not a member."""
+        value = _id_proxy(value)
+        if value not in self:
+            # Report the caller's object rather than the internal proxy.
+            raise KeyError(value.obj)
+        self.set.remove(value)
+
+    def discard(self, value):
+        """Remove ``value`` if present."""
+        self.set.discard(_id_proxy(value))
+
+    def pop(self):
+        """Remove a member and return it unwrapped."""
+        proxied = self.set.pop()
+        return proxied.obj
+
+    def issubset(self, iterable):
+        """Return True if every member of this set is also in ``iterable``."""
+        if not isinstance(iterable, type(self)):
+            iterable = type(self)(iterable)
+        # Compare proxies with proxies.  Handing over the IdentitySet itself
+        # would make the underlying set iterate *unwrapped* members, which
+        # only matches when an object's hash happens to equal its id().
+        return self.set.issubset(iterable.set)
+    __le__ = issubset
+
+    def __lt__(self, iterable):
+        if not isinstance(iterable, type(self)):
+            iterable = type(self)(iterable)
+        return len(self) < len(iterable) and self.issubset(iterable)
+
+    def issuperset(self, iterable):
+        """Return True if every item of ``iterable`` is in this set."""
+        if not isinstance(iterable, type(self)):
+            iterable = type(self)(iterable)
+        # See issubset(): operate on the proxy sets of both operands.
+        return self.set.issuperset(iterable.set)
+    __ge__ = issuperset
+
+    def __gt__(self, iterable):
+        if not isinstance(iterable, type(self)):
+            iterable = type(self)(iterable)
+        return len(self) > len(iterable) and self.issuperset(iterable)
+
+    def __eq__(self, other):
+        # Only another IdentitySet can compare equal.
+        if isinstance(other, IdentitySet):
+            return self.set == other.set
+        else:
+            return False
+
+    def __ne__(self, other):
+        if isinstance(other, IdentitySet):
+            return self.set != other.set
+        else:
+            return True
+
+    def __cmp__(self, other):
+        raise TypeError('cannot compare sets using cmp()')
+
+    def clear(self):
+        """Remove all members."""
+        self.set.clear()
+
+    def copy(self):
+        """Return a new set of the same type holding the same members."""
+        return type(self)(self.set)
+
+    def union(self, iterable):
+        """Return a new set with the members of this set and ``iterable``."""
+        return type(self)(self.set.union(_proxyiter(iterable)))
+
+    def __or__(self, iterable):
+        if not isinstance(iterable, set_types + (IdentitySet,)):
+            return NotImplemented
+        return self.union(iterable)
+    # NOTE(review): unlike __or__, the reflected form accepts any iterable
+    # because it is bound straight to union(); kept as-is for compatibility.
+    __ror__ = union
+
+    def update(self, iterable):
+        """Add every item of ``iterable`` to this set."""
+        self.set.update(_proxyiter(iterable))
+
+    def __ior__(self, iterable):
+        if not isinstance(iterable, set_types + (IdentitySet,)):
+            return NotImplemented
+        self.update(iterable)
+        return self
+
+    def difference(self, iterable):
+        """Return a new set of the members that are not in ``iterable``."""
+        return type(self)(self.set.difference(_proxyiter(iterable)))
+
+    def __sub__(self, iterable):
+        if not isinstance(iterable, set_types + (IdentitySet,)):
+            return NotImplemented
+        return self.difference(iterable)
+
+    def __rsub__(self, iterable):
+        # Subtraction is not commutative: the reflected call must compute
+        # ``iterable - self``, not ``self - iterable``.
+        if not isinstance(iterable, set_types + (IdentitySet,)):
+            return NotImplemented
+        return type(self)(iterable).difference(self)
+
+    def difference_update(self, iterable):
+        """Remove every item of ``iterable`` from this set."""
+        self.set.difference_update(_proxyiter(iterable))
+
+    def __isub__(self, iterable):
+        if not isinstance(iterable, set_types + (IdentitySet,)):
+            return NotImplemented
+        self.difference_update(iterable)
+        return self
+
+    def intersection(self, iterable):
+        """Return a new set of the members that are also in ``iterable``."""
+        return type(self)(self.set.intersection(_proxyiter(iterable)))
+
+    def __and__(self, iterable):
+        if not isinstance(iterable, set_types + (IdentitySet,)):
+            return NotImplemented
+        return self.intersection(iterable)
+    __rand__ = __and__
+
+    def intersection_update(self, iterable):
+        """Keep only the members that are also in ``iterable``."""
+        self.set.intersection_update(_proxyiter(iterable))
+
+    def __iand__(self, iterable):
+        if not isinstance(iterable, set_types + (IdentitySet,)):
+            return NotImplemented
+        self.intersection_update(iterable)
+        return self
+
+    def symmetric_difference(self, iterable):
+        """Return a new set of the items in exactly one of the operands."""
+        return type(self)(self.set.symmetric_difference(_proxyiter(iterable)))
+
+    def __xor__(self, iterable):
+        if not isinstance(iterable, set_types + (IdentitySet,)):
+            return NotImplemented
+        return self.symmetric_difference(iterable)
+    __rxor__ = __xor__
+
+    def symmetric_difference_update(self, iterable):
+        """Keep only the items found in exactly one of the operands."""
+        self.set.symmetric_difference_update(_proxyiter(iterable))
+
+    def __ixor__(self, iterable):
+        if not isinstance(iterable, set_types + (IdentitySet,)):
+            return NotImplemented
+        self.symmetric_difference_update(iterable)
+        return self
+
+    def __iter__(self):
+        # Yield the original objects, never the internal proxies.
+        for proxy in self.set:
+            assert isinstance(proxy, self._IdentityProxy)
+            yield proxy.obj
+
+    def __len__(self):
+        return len(self.set)
+
+    def __contains__(self, value):
+        return _id_proxy(value) in self.set
+
+    def __hash__(self):
+        raise TypeError('set objects are unhashable')
+
+    def __repr__(self):
+        return '%s(%r)' % (type(self).__name__, list(self))
+
+def _proxyiter(iterable):
+    """Return an iterator that passes each item through _id_proxy()."""
+    wrap = _id_proxy
+    return itertools.imap(wrap, iterable)
+
+def _id_proxy(item):
+    """Return item wrapped in an _IdentityProxy, unless it already is one."""
+    proxy_type = IdentitySet._IdentityProxy
+    if not isinstance(item, proxy_type):
+        item = proxy_type(item)
+    return item
+
+
 class UniqueAppender(object):
     """appends items to a collection such that only unique items
     are added."""
index 28258e9c3dd455cd9f121d68e6f4516e1e8884c9..1cfcd8fb5aba4bb65141b6d04d4b5ff5251f6547 100644 (file)
@@ -83,7 +83,144 @@ class ArgSingletonTest(unittest.TestCase):
         m1 = m2 = m3 = None
         MyClass.dispose(MyClass)
         assert len(util.ArgSingleton.instances) == 0
-        
-        
+
+class ImmutableSubclass(str):
+    """A str subclass that adds nothing; used as IdentitySet test data."""
+
+class HashOverride(object):
+    """Test fixture whose hash is the hash of its ``value`` attribute.
+
+    Equality is not overridden.
+    """
+    def __init__(self, value=None):
+        self.value = value
+    def __hash__(self):
+        wrapped = self.value
+        return hash(wrapped)
+
+class EqOverride(object):
+    """Test fixture comparing by ``value``; __hash__ is not overridden.
+
+    Anything that is not an EqOverride is always unequal.
+    """
+    def __init__(self, value=None):
+        self.value = value
+    def __eq__(self, other):
+        if not isinstance(other, EqOverride):
+            return False
+        return self.value == other.value
+    def __ne__(self, other):
+        if not isinstance(other, EqOverride):
+            return True
+        return self.value != other.value
+
+class HashEqOverride(object):
+    """Test fixture whose hash and equality both derive from ``value``.
+
+    Two distinct instances built from the same value hash alike and compare
+    equal, which is the case an identity-based set must still keep apart.
+    """
+    def __init__(self, value=None):
+        self.value = value
+    def __hash__(self):
+        return hash(self.value)
+    def __eq__(self, other):
+        # Compare against this class, not EqOverride; otherwise two
+        # HashEqOverride instances could never be equal.
+        if isinstance(other, HashEqOverride):
+            return self.value == other.value
+        else:
+            return False
+    def __ne__(self, other):
+        if isinstance(other, HashEqOverride):
+            return self.value != other.value
+        else:
+            return True
+
+
+class IdentitySetTest(unittest.TestCase):
+    """Exercises util.IdentitySet construction, mutation and set algebra."""
+
+    def assert_eq(self, identityset, expected_iterable):
+        # Compare contents irrespective of iteration order.
+        found = sorted(identityset)
+        expected = sorted(expected_iterable)
+        self.assertEquals(found, expected)
+
+    def test_init(self):
+        # Construction from a list with repeats, from another IdentitySet,
+        # from nothing, and from empty iterables.
+        ids = util.IdentitySet([1,2,3,2,1])
+        self.assert_eq(ids, [1,2,3])
+
+        ids = util.IdentitySet(ids)
+        self.assert_eq(ids, [1,2,3])
+
+        ids = util.IdentitySet()
+        self.assert_eq(ids, [])
+
+        ids = util.IdentitySet([])
+        self.assert_eq(ids, [])
+
+        ids = util.IdentitySet(ids)
+        self.assert_eq(ids, [])
+
+    def test_add(self):
+        # Each distinct object is stored once, even when added twice.
+        for type_ in (object, ImmutableSubclass):
+            data = [type_(), type_()]
+            ids = util.IdentitySet()
+            for i in range(2) + range(2):
+                ids.add(data[i])
+            self.assert_eq(ids, data)
+
+        # Fixtures overriding __eq__ and/or __hash__: two of the three are
+        # built from the same value, yet all three must be kept.
+        for type_ in (EqOverride, HashOverride, HashEqOverride):
+            data = [type_(1), type_(1), type_(2)]
+            ids = util.IdentitySet()
+            for i in range(3) + range(3):
+                ids.add(data[i])
+            self.assert_eq(ids, data)
+
+    def test_basic_sanity(self):
+        IdentitySet = util.IdentitySet
+
+        o1, o2, o3 = object(), object(), object()
+        ids = IdentitySet([o1])
+        # discard() of a missing member is silent; remove() raises.
+        ids.discard(o1)
+        ids.discard(o1)
+        ids.add(o1)
+        ids.remove(o1)
+        self.assertRaises(KeyError, ids.remove, o1)
+
+        self.assert_(ids.copy() == ids)
+        self.assert_(ids != None)
+        self.assert_(not(ids == None))
+        self.assert_(ids != IdentitySet([o1,o2,o3]))
+        ids.clear()
+        self.assert_(o1 not in ids)
+        ids.add(o2)
+        self.assert_(o2 in ids)
+        self.assert_(ids.pop() == o2)
+        ids.add(o1)
+        self.assert_(len(ids) == 1)
+
+        # ids == {o1}, isuper == {o1, o2}
+        isuper = IdentitySet([o1,o2])
+        self.assert_(ids < isuper)
+        self.assert_(ids.issubset(isuper))
+        self.assert_(isuper.issuperset(ids))
+        self.assert_(isuper > ids)
+
+        self.assert_(ids.union(isuper) == isuper)
+        self.assert_(ids | isuper == isuper)
+        self.assert_(isuper - ids == IdentitySet([o2]))
+        self.assert_(isuper.difference(ids) == IdentitySet([o2]))
+        self.assert_(ids.intersection(isuper) == IdentitySet([o1]))
+        self.assert_(ids & isuper == IdentitySet([o1]))
+        self.assert_(ids.symmetric_difference(isuper) == IdentitySet([o2]))
+        self.assert_(ids ^ isuper == IdentitySet([o2]))
+
+        # Smoke-test the in-place variants: they only need to run.
+        ids.update(isuper)
+        ids |= isuper
+        ids.difference_update(isuper)
+        ids -= isuper
+        ids.intersection_update(isuper)
+        ids &= isuper
+        ids.symmetric_difference_update(isuper)
+        ids ^= isuper
+
+        # Named methods take any iterable; the operators insist on a set.
+        ids.update('foobar')
+        try:
+            ids |= 'foobar'
+            self.assert_(False)
+        except TypeError:
+            self.assert_(True)
+        s = set([o1,o2])
+        s |= ids
+        self.assert_(isinstance(s, IdentitySet))
+
+        # Call __cmp__ directly: cmp(ids) with a single argument raises
+        # TypeError for its arity and never reaches IdentitySet.__cmp__.
+        self.assertRaises(TypeError, ids.__cmp__, ids)
+        self.assertRaises(TypeError, hash, ids)
+
+
 if __name__ == "__main__":
     testbase.main()