"""
- _working_set = set
-
def __init__(self, iterable=None):
self._members = dict()
if iterable:
- for o in iterable:
- self.add(o)
+ self.update(iterable)
def add(self, value):
self._members[id(value)] = value
return True
def issubset(self, iterable):
- other = type(self)(iterable)
+ other = self.__class__(iterable)
if len(self) > len(other):
return False
return len(self) < len(other) and self.issubset(other)
def issuperset(self, iterable):
- other = type(self)(iterable)
+ other = self.__class__(iterable)
if len(self) < len(other):
return False
return len(self) > len(other) and self.issuperset(other)
def union(self, iterable):
- result = type(self)()
- # testlib.pragma exempt:__hash__
- members = self._member_id_tuples()
- other = _iter_id(iterable)
- result._members.update(self._working_set(members).union(other))
+ result = self.__class__()
+ members = self._members
+ result._members.update(members)
+ result._members.update((id(obj), obj) for obj in iterable)
return result
def __or__(self, other):
return self.union(other)
def update(self, iterable):
- self._members = self.union(iterable)._members
+ self._members.update((id(obj), obj) for obj in iterable)
def __ior__(self, other):
if not isinstance(other, IdentitySet):
return self
def difference(self, iterable):
- result = type(self)()
- # testlib.pragma exempt:__hash__
- members = self._member_id_tuples()
- other = _iter_id(iterable)
- result._members.update(self._working_set(members).difference(other))
+ result = self.__class__()
+ members = self._members
+ other = {id(obj) for obj in iterable}
+ result._members.update(
+ ((k, v) for k, v in members.items() if k not in other)
+ )
return result
def __sub__(self, other):
return self
def intersection(self, iterable):
- result = type(self)()
- # testlib.pragma exempt:__hash__
- members = self._member_id_tuples()
- other = _iter_id(iterable)
- result._members.update(self._working_set(members).intersection(other))
+ result = self.__class__()
+ members = self._members
+ other = {id(obj) for obj in iterable}
+ result._members.update(
+ (k, v) for k, v in members.items() if k in other
+ )
return result
def __and__(self, other):
return self
def symmetric_difference(self, iterable):
- result = type(self)()
- # testlib.pragma exempt:__hash__
- members = self._member_id_tuples()
- other = _iter_id(iterable)
+ result = self.__class__()
+ members = self._members
+ other = {id(obj): obj for obj in iterable}
result._members.update(
- self._working_set(members).symmetric_difference(other)
+ ((k, v) for k, v in members.items() if k not in other)
+ )
+ result._members.update(
+ ((k, v) for k, v in other.items() if k not in members)
)
return result
- def _member_id_tuples(self):
- return ((id(v), v) for v in self._members.values())
-
def __xor__(self, other):
if not isinstance(other, IdentitySet):
return NotImplemented
class OrderedIdentitySet(IdentitySet):
- class _working_set(OrderedSet):
- # a testing pragma: exempt the OIDS working set from the test suite's
- # "never call the user's __hash__" assertions. this is a big hammer,
- # but it's safe here: IDS operates on (id, instance) tuples in the
- # working set.
- __sa_hash_exempt__ = True
-
def __init__(self, iterable=None):
IdentitySet.__init__(self)
self._members = OrderedDict()
pass
-def _iter_id(iterable):
- """Generator: ((id(o), o) for o in iterable)."""
-
- for item in iterable:
- yield id(item), item
-
-
def has_dupes(sequence, target):
"""Given a sequence and search object, return True if there's more
than one, False if zero or one of them.
return hash(self.value)
+class NoHash(object):
+ def __init__(self, value=None):
+ self.value = value
+
+ __hash__ = None
+
+
class EqOverride(object):
def __init__(self, value=None):
self.value = value
class IdentitySetTest(fixtures.TestBase):
+ obj_type = object
+
def assert_eq(self, identityset, expected_iterable):
expected = sorted([id(o) for o in expected_iterable])
found = sorted([id(o) for o in identityset])
ids.add(data[i])
self.assert_eq(ids, data)
- for type_ in (EqOverride, HashOverride, HashEqOverride):
+ for type_ in (NoHash, EqOverride, HashOverride, HashEqOverride):
data = [type_(1), type_(1), type_(2)]
ids = util.IdentitySet()
for i in list(range(3)) + list(range(3)):
def test_dunder_sub2(self):
IdentitySet = util.IdentitySet
- o1, o2, o3 = object(), object(), object()
+ o1, o2, o3 = self.obj_type(), self.obj_type(), self.obj_type()
ids1 = IdentitySet([o1])
ids2 = IdentitySet([o1, o2, o3])
eq_(ids2 - ids1, IdentitySet([o2, o3]))
pass # TODO
def _create_sets(self):
- o1, o2, o3, o4, o5 = object(), object(), object(), object(), object()
+ o1, o2, o3, o4, o5 = (
+ self.obj_type(),
+ self.obj_type(),
+ self.obj_type(),
+ self.obj_type(),
+ self.obj_type(),
+ )
super_ = util.IdentitySet([o1, o2, o3])
sub_ = util.IdentitySet([o2])
twin1 = util.IdentitySet([o3])
def test_basic_sanity(self):
IdentitySet = util.IdentitySet
- o1, o2, o3 = object(), object(), object()
+ o1, o2, o3 = self.obj_type(), self.obj_type(), self.obj_type()
ids = IdentitySet([o1])
ids.discard(o1)
ids.discard(o1)
assert_raises(TypeError, hash, ids)
+class NoHashIdentitySetTest(IdentitySetTest):
+ obj_type = NoHash
+
+
class OrderedIdentitySetTest(fixtures.TestBase):
def assert_eq(self, identityset, expected_iterable):
expected = [id(o) for o in expected_iterable]