From: Mike Bayer Date: Mon, 4 May 2020 00:27:24 +0000 (-0400) Subject: Don't apply sets or similar to objects in IdentitySet X-Git-Tag: rel_1_4_0b1~348^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=35552e88ca798b809c7391bae11890c1557a3dd2;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Don't apply sets or similar to objects in IdentitySet Modified the internal "identity set" implementation, which is a set that hashes objects on their id() rather than their hash values, to not actually call the ``__hash__()`` method of the objects, which are typically user-mapped objects. Some methods were calling this method as a side effect of the implementation. Fixes: #5304 Change-Id: I0ed8762f47622215a54dcad9f210377b1becf8e8 --- diff --git a/doc/build/changelog/unreleased_13/5304.rst b/doc/build/changelog/unreleased_13/5304.rst new file mode 100644 index 0000000000..d08db88492 --- /dev/null +++ b/doc/build/changelog/unreleased_13/5304.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, orm + :tickets: 5304 + + Modified the internal "identity set" implementation, which is a set that + hashes objects on their id() rather than their hash values, to not actually + call the ``__hash__()`` method of the objects, which are typically + user-mapped objects. Some methods were calling this method as a side + effect of the implementation. + diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index b21eb44cfb..10d80fc987 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -363,13 +363,10 @@ class IdentitySet(object): """ - _working_set = set - def __init__(self, iterable=None): self._members = dict() if iterable: - for o in iterable: - self.add(o) + self.update(iterable) def add(self, value): self._members[id(value)] = value @@ -412,7 +409,7 @@ class IdentitySet(object): return True def issubset(self, iterable): - other = type(self)(iterable) + other = self.__class__(iterable) if len(self) > len(other): return False @@ -433,7 +430,7 @@ class IdentitySet(object): return len(self) < len(other) and self.issubset(other) def issuperset(self, iterable): - other = type(self)(iterable) + other = self.__class__(iterable) if len(self) < len(other): return False @@ -455,11 +452,10 @@ class IdentitySet(object): return len(self) > len(other) and self.issuperset(other) def union(self, iterable): - result = type(self)() - # testlib.pragma exempt:__hash__ - members = self._member_id_tuples() - other = _iter_id(iterable) - result._members.update(self._working_set(members).union(other)) + result = self.__class__() + members = self._members + result._members.update(members) + result._members.update((id(obj), obj) for obj in iterable) return result def __or__(self, other): @@ -468,7 +464,7 @@ class IdentitySet(object): return self.union(other) def update(self, iterable): - self._members = self.union(iterable)._members + self._members.update((id(obj), obj) for obj in iterable) def __ior__(self, other): if not isinstance(other, IdentitySet): @@ -477,11 +473,12 @@ class IdentitySet(object): return self def difference(self, iterable): - result = type(self)() - # testlib.pragma exempt:__hash__ - members = self._member_id_tuples() - other = _iter_id(iterable) - result._members.update(self._working_set(members).difference(other)) + result = self.__class__() + members = self._members + other = {id(obj) for obj in iterable} + result._members.update( + ((k, v) for k, v in members.items() if k not in other) + ) return result def __sub__(self, other): @@ -499,11 +496,12 @@ class IdentitySet(object): return self def intersection(self, iterable): - result = type(self)() - # testlib.pragma exempt:__hash__ - members = self._member_id_tuples() - other = _iter_id(iterable) - result._members.update(self._working_set(members).intersection(other)) + result = self.__class__() + members = self._members + other = {id(obj) for obj in iterable} + result._members.update( + (k, v) for k, v in members.items() if k in other + ) return result def __and__(self, other): @@ -521,18 +519,17 @@ class IdentitySet(object): return self def symmetric_difference(self, iterable): - result = type(self)() - # testlib.pragma exempt:__hash__ - members = self._member_id_tuples() - other = _iter_id(iterable) + result = self.__class__() + members = self._members + other = {id(obj): obj for obj in iterable} result._members.update( - self._working_set(members).symmetric_difference(other) + ((k, v) for k, v in members.items() if k not in other) + ) + result._members.update( + ((k, v) for k, v in other.items() if k not in members) ) return result - def _member_id_tuples(self): - return ((id(v), v) for v in self._members.values()) - def __xor__(self, other): if not isinstance(other, IdentitySet): return NotImplemented @@ -600,13 +597,6 @@ class WeakSequence(object): class OrderedIdentitySet(IdentitySet): - class _working_set(OrderedSet): - # a testing pragma: exempt the OIDS working set from the test suite's - # "never call the user's __hash__" assertions. this is a big hammer, - # but it's safe here: IDS operates on (id, instance) tuples in the - # working set. - __sa_hash_exempt__ = True - def __init__(self, iterable=None): IdentitySet.__init__(self) self._members = OrderedDict() @@ -942,13 +932,6 @@ class ThreadLocalRegistry(ScopedRegistry): pass -def _iter_id(iterable): - """Generator: ((id(o), o) for o in iterable).""" - - for item in iterable: - yield id(item), item - - def has_dupes(sequence, target): """Given a sequence and search object, return True if there's more than one, False if zero or one of them. diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 96a9f955a0..a6d777c61d 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -1058,6 +1058,13 @@ class HashOverride(object): return hash(self.value) +class NoHash(object): + def __init__(self, value=None): + self.value = value + + __hash__ = None + + class EqOverride(object): def __init__(self, value=None): self.value = value @@ -1098,6 +1105,8 @@ class HashEqOverride(object): class IdentitySetTest(fixtures.TestBase): + obj_type = object + def assert_eq(self, identityset, expected_iterable): expected = sorted([id(o) for o in expected_iterable]) found = sorted([id(o) for o in identityset]) @@ -1127,7 +1136,7 @@ class IdentitySetTest(fixtures.TestBase): ids.add(data[i]) self.assert_eq(ids, data) - for type_ in (EqOverride, HashOverride, HashEqOverride): + for type_ in (NoHash, EqOverride, HashOverride, HashEqOverride): data = [type_(1), type_(1), type_(2)] ids = util.IdentitySet() for i in list(range(3)) + list(range(3)): @@ -1136,7 +1145,7 @@ class IdentitySetTest(fixtures.TestBase): def test_dunder_sub2(self): IdentitySet = util.IdentitySet - o1, o2, o3 = object(), object(), object() + o1, o2, o3 = self.obj_type(), self.obj_type(), self.obj_type() ids1 = IdentitySet([o1]) ids2 = IdentitySet([o1, o2, o3]) eq_(ids2 - ids1, IdentitySet([o2, o3])) @@ -1549,7 +1558,13 @@ class IdentitySetTest(fixtures.TestBase): pass # TODO def _create_sets(self): - o1, o2, o3, o4, o5 = object(), object(), object(), object(), object() + o1, o2, o3, o4, o5 = ( + self.obj_type(), + self.obj_type(), + self.obj_type(), + self.obj_type(), + self.obj_type(), + ) super_ = util.IdentitySet([o1, o2, o3]) sub_ = util.IdentitySet([o2]) twin1 = util.IdentitySet([o3]) @@ -1573,7 +1588,7 @@ class IdentitySetTest(fixtures.TestBase): def test_basic_sanity(self): IdentitySet = util.IdentitySet - o1, o2, o3 = object(), object(), object() + o1, o2, o3 = self.obj_type(), self.obj_type(), self.obj_type() ids = IdentitySet([o1]) ids.discard(o1) ids.discard(o1) @@ -1638,6 +1653,10 @@ class IdentitySetTest(fixtures.TestBase): assert_raises(TypeError, hash, ids) +class NoHashIdentitySetTest(IdentitySetTest): + obj_type = NoHash + + class OrderedIdentitySetTest(fixtures.TestBase): def assert_eq(self, identityset, expected_iterable): expected = [id(o) for o in expected_iterable]