# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import sys, warnings, sets
+import itertools, sys, warnings, sets
import __builtin__
from sqlalchemy import exceptions, logging
__isub__ = difference_update
+class IdentitySet(object):
+ """A set that considers only object id() for uniqueness.
+
+ This strategy has edge cases for builtin types- it's possible to have
+ two 'foo' strings in one of these sets, for example. Use sparingly.
+ """
+
+ class _IdentityProxy(object):
+ """Proxies an object's id() as its hash and basis for equality."""
+
+ __slots__ = ('obj',)
+
+ def __init__(self, value):
+ self.obj = value
+ def __hash__(self):
+ return id(self.obj)
+ def __eq__(self, other):
+ if isinstance(other, type(self)):
+ return id(self.obj) == id(other.obj)
+ else:
+ return id(self.obj) == id(other)
+ def __ne__(self, other):
+ if isinstance(other, type(self)):
+ return id(self.obj) != id(other.obj)
+ else:
+ return id(self.obj) != id(other)
+
+ def __init__(self, iterable=None):
+ self.set = Set()
+ if iterable:
+ for o in iterable:
+ self.add(o)
+
+ def add(self, value):
+ self.set.add(_id_proxy(value))
+
+ def remove(self, value):
+ value = _id_proxy(value)
+ if value not in self:
+ raise KeyError(value.obj)
+ self.set.remove(value)
+
+ def discard(self, value):
+ self.set.discard(_id_proxy(value))
+
+ def pop(self):
+ proxied = self.set.pop()
+ return proxied.obj
+
+ def issubset(self, iterable):
+ if not isinstance(iterable, type(self)):
+ iterable = type(self)(iterable)
+ return self.set.issubset(iterable)
+ __le__ = issubset
+
+ def __lt__(self, iterable):
+ if not isinstance(iterable, type(self)):
+ iterable = type(self)(iterable)
+ return len(self) < len(iterable) and self.issubset(iterable)
+
+ def issuperset(self, iterable):
+ if not isinstance(iterable, type(self)):
+ iterable = type(self)(iterable)
+ return self.set.issuperset(iterable)
+ __ge__ = issuperset
+
+ def __gt__(self, iterable):
+ if not isinstance(iterable, type(self)):
+ iterable = type(self)(iterable)
+ return len(self) > len(iterable) and self.issuperset(iterable)
+
+ def __eq__(self, other):
+ if isinstance(other, IdentitySet):
+ return self.set == other.set
+ else:
+ return False
+
+ def __ne__(self, other):
+ if isinstance(other, IdentitySet):
+ return self.set != other.set
+ else:
+ return True
+
+ def __cmp__(self, other):
+ raise TypeError('cannot compare sets using cmp()')
+
+ def clear(self):
+ self.set.clear()
+
+ def copy(self):
+ return type(self)(self.set)
+
+ def union(self, iterable):
+ return type(self)(self.set.union(_proxyiter(iterable)))
+
+ def __or__(self, iterable):
+ if not isinstance(iterable, set_types + (IdentitySet,)):
+ return NotImplemented
+ return self.union(iterable)
+ __ror__ = union
+
+ def update(self, iterable):
+ self.set.update(_proxyiter(iterable))
+
+ def __ior__(self, iterable):
+ if not isinstance(iterable, set_types + (IdentitySet,)):
+ return NotImplemented
+ self.update(iterable)
+ return self
+
+ def difference(self, iterable):
+ return type(self)(self.set.difference(_proxyiter(iterable)))
+
+ def __sub__(self, iterable):
+ if not isinstance(iterable, set_types + (IdentitySet,)):
+ return NotImplemented
+ return self.difference(iterable)
+ __rsub__ = __sub__
+
+ def difference_update(self, iterable):
+ self.set.difference_update(_proxyiter(iterable))
+
+ def __isub__(self, iterable):
+ if not isinstance(iterable, set_types + (IdentitySet,)):
+ return NotImplemented
+ self.difference_update(iterable)
+ return self
+
+ def intersection(self, iterable):
+ return type(self)(self.set.intersection(_proxyiter(iterable)))
+
+ def __and__(self, iterable):
+ if not isinstance(iterable, set_types + (IdentitySet,)):
+ return NotImplemented
+ return self.intersection(iterable)
+ __rand__ = __and__
+
+ def intersection_update(self, iterable):
+ self.set.intersection_update(_proxyiter(iterable))
+
+ def __iand__(self, iterable):
+ if not isinstance(iterable, set_types + (IdentitySet,)):
+ return NotImplemented
+ self.intersection_update(iterable)
+ return self
+
+ def symmetric_difference(self, iterable):
+ return type(self)(self.set.symmetric_difference(_proxyiter(iterable)))
+
+ def __xor__(self, iterable):
+ if not isinstance(iterable, set_types + (IdentitySet,)):
+ return NotImplemented
+ return self.symmetric_difference(iterable)
+ __rxor__ = __xor__
+
+ def symmetric_difference_update(self, iterable):
+ self.set.symmetric_difference_update(_proxyiter(iterable))
+
+ def __ixor__(self, iterable):
+ if not isinstance(iterable, set_types + (IdentitySet,)):
+ return NotImplemented
+ self.symmetric_difference_update(iterable)
+ return self
+
+ def __iter__(self):
+ for proxy in self.set:
+ assert isinstance(proxy, self._IdentityProxy)
+ yield proxy.obj
+
+ def __len__(self):
+ return len(self.set)
+
+ def __contains__(self, value):
+ return _id_proxy(value) in self.set
+
+ def __hash__(self):
+ raise TypeError('set objects are unhashable')
+
+ def __repr__(self):
+ return '%s(%r)' % (type(self).__name__, list(self))
+
+def _proxyiter(iterable):
+ return itertools.imap(_id_proxy, iterable)
+
+def _id_proxy(item):
+ if isinstance(item, IdentitySet._IdentityProxy):
+ return item
+ else:
+ return IdentitySet._IdentityProxy(item)
+
+
class UniqueAppender(object):
"""appends items to a collection such that only unique items
are added."""
m1 = m2 = m3 = None
MyClass.dispose(MyClass)
assert len(util.ArgSingleton.instances) == 0
-
-
+
+class ImmutableSubclass(str):
+ pass
+
+class HashOverride(object):
+ def __init__(self, value=None):
+ self.value = value
+ def __hash__(self):
+ return hash(self.value)
+
+class EqOverride(object):
+ def __init__(self, value=None):
+ self.value = value
+ def __eq__(self, other):
+ if isinstance(other, EqOverride):
+ return self.value == other.value
+ else:
+ return False
+ def __ne__(self, other):
+ if isinstance(other, EqOverride):
+ return self.value != other.value
+ else:
+ return True
+
+class HashEqOverride(object):
+ def __init__(self, value=None):
+ self.value = value
+ def __hash__(self):
+ return hash(self.value)
+ def __eq__(self, other):
+ if isinstance(other, EqOverride):
+ return self.value == other.value
+ else:
+ return False
+ def __ne__(self, other):
+ if isinstance(other, EqOverride):
+ return self.value != other.value
+ else:
+ return True
+
+
+class IdentitySetTest(unittest.TestCase):
+ def assert_eq(self, identityset, expected_iterable):
+ found = sorted(list(identityset))
+ expected = sorted(expected_iterable)
+ self.assertEquals(found, expected)
+
+ def test_init(self):
+ ids = util.IdentitySet([1,2,3,2,1])
+ self.assert_eq(ids, [1,2,3])
+
+ ids = util.IdentitySet(ids)
+ self.assert_eq(ids, [1,2,3])
+
+ ids = util.IdentitySet()
+ self.assert_eq(ids, [])
+
+ ids = util.IdentitySet([])
+ self.assert_eq(ids, [])
+
+ ids = util.IdentitySet(ids)
+ self.assert_eq(ids, [])
+
+ def test_add(self):
+ for type_ in (object, ImmutableSubclass):
+ data = [type_(), type_()]
+ ids = util.IdentitySet()
+ for i in range(2) + range(2):
+ ids.add(data[i])
+ self.assert_eq(ids, data)
+
+ for type_ in (EqOverride, HashOverride, HashEqOverride):
+ data = [type_(1), type_(1), type_(2)]
+ ids = util.IdentitySet()
+ for i in range(3) + range(3):
+ ids.add(data[i])
+ self.assert_eq(ids, data)
+
+ def test_basic_sanity(self):
+ IdentitySet = util.IdentitySet
+
+ o1, o2, o3 = object(), object(), object()
+ ids = IdentitySet([o1])
+ ids.discard(o1)
+ ids.discard(o1)
+ ids.add(o1)
+ ids.remove(o1)
+ self.assertRaises(KeyError, ids.remove, o1)
+
+ self.assert_(ids.copy() == ids)
+ self.assert_(ids != None)
+ self.assert_(not(ids == None))
+ self.assert_(ids != IdentitySet([o1,o2,o3]))
+ ids.clear()
+ self.assert_(o1 not in ids)
+ ids.add(o2)
+ self.assert_(o2 in ids)
+ self.assert_(ids.pop() == o2)
+ ids.add(o1)
+ self.assert_(len(ids) == 1)
+
+ isuper = IdentitySet([o1,o2])
+ self.assert_(ids < isuper)
+ self.assert_(ids.issubset(isuper))
+ self.assert_(isuper.issuperset(ids))
+ self.assert_(isuper > ids)
+
+ self.assert_(ids.union(isuper) == isuper)
+ self.assert_(ids | isuper == isuper)
+ self.assert_(isuper - ids == IdentitySet([o2]))
+ self.assert_(isuper.difference(ids) == IdentitySet([o2]))
+ self.assert_(ids.intersection(isuper) == IdentitySet([o1]))
+ self.assert_(ids & isuper == IdentitySet([o1]))
+ self.assert_(ids.symmetric_difference(isuper) == IdentitySet([o2]))
+ self.assert_(ids ^ isuper == IdentitySet([o2]))
+
+ ids.update(isuper)
+ ids |= isuper
+ ids.difference_update(isuper)
+ ids -= isuper
+ ids.intersection_update(isuper)
+ ids &= isuper
+ ids.symmetric_difference_update(isuper)
+ ids ^= isuper
+
+ ids.update('foobar')
+ try:
+ ids |= 'foobar'
+ self.assert_(False)
+ except TypeError:
+ self.assert_(True)
+ s = set([o1,o2])
+ s |= ids
+ self.assert_(isinstance(s, IdentitySet))
+
+ self.assertRaises(TypeError, cmp, ids)
+ self.assertRaises(TypeError, hash, ids)
+
+
if __name__ == "__main__":
testbase.main()