]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Don't apply sets or similar to objects in IdentitySet
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 4 May 2020 00:27:24 +0000 (20:27 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 4 May 2020 00:33:17 +0000 (20:33 -0400)
Modified the internal "identity set" implementation, which is a set that
hashes objects on their id() rather than their hash values, to not actually
call the ``__hash__()`` method of the objects, which are typically
user-mapped objects.  Some methods were calling this method as a side
effect of the implementation.

Fixes: #5304
Change-Id: I0ed8762f47622215a54dcad9f210377b1becf8e8

doc/build/changelog/unreleased_13/5304.rst [new file with mode: 0644]
lib/sqlalchemy/util/_collections.py
test/base/test_utils.py

diff --git a/doc/build/changelog/unreleased_13/5304.rst b/doc/build/changelog/unreleased_13/5304.rst
new file mode 100644 (file)
index 0000000..d08db88
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 5304
+
+    Modified the internal "identity set" implementation, which is a set that
+    hashes objects on their id() rather than their hash values, to not actually
+    call the ``__hash__()`` method of the objects, which are typically
+    user-mapped objects.  Some methods were calling this method as a side
+    effect of the implementation.
+
index b21eb44cfb19970b21cbf9055452151f4b12820c..10d80fc987c952ce4763212804f449e35266e099 100644 (file)
@@ -363,13 +363,10 @@ class IdentitySet(object):
 
     """
 
-    _working_set = set
-
     def __init__(self, iterable=None):
         self._members = dict()
         if iterable:
-            for o in iterable:
-                self.add(o)
+            self.update(iterable)
 
     def add(self, value):
         self._members[id(value)] = value
@@ -412,7 +409,7 @@ class IdentitySet(object):
             return True
 
     def issubset(self, iterable):
-        other = type(self)(iterable)
+        other = self.__class__(iterable)
 
         if len(self) > len(other):
             return False
@@ -433,7 +430,7 @@ class IdentitySet(object):
         return len(self) < len(other) and self.issubset(other)
 
     def issuperset(self, iterable):
-        other = type(self)(iterable)
+        other = self.__class__(iterable)
 
         if len(self) < len(other):
             return False
@@ -455,11 +452,10 @@ class IdentitySet(object):
         return len(self) > len(other) and self.issuperset(other)
 
     def union(self, iterable):
-        result = type(self)()
-        # testlib.pragma exempt:__hash__
-        members = self._member_id_tuples()
-        other = _iter_id(iterable)
-        result._members.update(self._working_set(members).union(other))
+        result = self.__class__()
+        members = self._members
+        result._members.update(members)
+        result._members.update((id(obj), obj) for obj in iterable)
         return result
 
     def __or__(self, other):
@@ -468,7 +464,7 @@ class IdentitySet(object):
         return self.union(other)
 
     def update(self, iterable):
-        self._members = self.union(iterable)._members
+        self._members.update((id(obj), obj) for obj in iterable)
 
     def __ior__(self, other):
         if not isinstance(other, IdentitySet):
@@ -477,11 +473,12 @@ class IdentitySet(object):
         return self
 
     def difference(self, iterable):
-        result = type(self)()
-        # testlib.pragma exempt:__hash__
-        members = self._member_id_tuples()
-        other = _iter_id(iterable)
-        result._members.update(self._working_set(members).difference(other))
+        result = self.__class__()
+        members = self._members
+        other = {id(obj) for obj in iterable}
+        result._members.update(
+            ((k, v) for k, v in members.items() if k not in other)
+        )
         return result
 
     def __sub__(self, other):
@@ -499,11 +496,12 @@ class IdentitySet(object):
         return self
 
     def intersection(self, iterable):
-        result = type(self)()
-        # testlib.pragma exempt:__hash__
-        members = self._member_id_tuples()
-        other = _iter_id(iterable)
-        result._members.update(self._working_set(members).intersection(other))
+        result = self.__class__()
+        members = self._members
+        other = {id(obj) for obj in iterable}
+        result._members.update(
+            (k, v) for k, v in members.items() if k in other
+        )
         return result
 
     def __and__(self, other):
@@ -521,18 +519,17 @@ class IdentitySet(object):
         return self
 
     def symmetric_difference(self, iterable):
-        result = type(self)()
-        # testlib.pragma exempt:__hash__
-        members = self._member_id_tuples()
-        other = _iter_id(iterable)
+        result = self.__class__()
+        members = self._members
+        other = {id(obj): obj for obj in iterable}
         result._members.update(
-            self._working_set(members).symmetric_difference(other)
+            ((k, v) for k, v in members.items() if k not in other)
+        )
+        result._members.update(
+            ((k, v) for k, v in other.items() if k not in members)
         )
         return result
 
-    def _member_id_tuples(self):
-        return ((id(v), v) for v in self._members.values())
-
     def __xor__(self, other):
         if not isinstance(other, IdentitySet):
             return NotImplemented
@@ -600,13 +597,6 @@ class WeakSequence(object):
 
 
 class OrderedIdentitySet(IdentitySet):
-    class _working_set(OrderedSet):
-        # a testing pragma: exempt the OIDS working set from the test suite's
-        # "never call the user's __hash__" assertions.  this is a big hammer,
-        # but it's safe here: IDS operates on (id, instance) tuples in the
-        # working set.
-        __sa_hash_exempt__ = True
-
     def __init__(self, iterable=None):
         IdentitySet.__init__(self)
         self._members = OrderedDict()
@@ -942,13 +932,6 @@ class ThreadLocalRegistry(ScopedRegistry):
             pass
 
 
-def _iter_id(iterable):
-    """Generator: ((id(o), o) for o in iterable)."""
-
-    for item in iterable:
-        yield id(item), item
-
-
 def has_dupes(sequence, target):
     """Given a sequence and search object, return True if there's more
     than one, False if zero or one of them.
index 96a9f955a07ce3372e2b8b16bfaf12b1cfdca6a0..a6d777c61df44b1417c204da017e7fcb517a0b2a 100644 (file)
@@ -1058,6 +1058,13 @@ class HashOverride(object):
         return hash(self.value)
 
 
+class NoHash(object):
+    def __init__(self, value=None):
+        self.value = value
+
+    __hash__ = None
+
+
 class EqOverride(object):
     def __init__(self, value=None):
         self.value = value
@@ -1098,6 +1105,8 @@ class HashEqOverride(object):
 
 
 class IdentitySetTest(fixtures.TestBase):
+    obj_type = object
+
     def assert_eq(self, identityset, expected_iterable):
         expected = sorted([id(o) for o in expected_iterable])
         found = sorted([id(o) for o in identityset])
@@ -1127,7 +1136,7 @@ class IdentitySetTest(fixtures.TestBase):
                 ids.add(data[i])
             self.assert_eq(ids, data)
 
-        for type_ in (EqOverride, HashOverride, HashEqOverride):
+        for type_ in (NoHash, EqOverride, HashOverride, HashEqOverride):
             data = [type_(1), type_(1), type_(2)]
             ids = util.IdentitySet()
             for i in list(range(3)) + list(range(3)):
@@ -1136,7 +1145,7 @@ class IdentitySetTest(fixtures.TestBase):
 
     def test_dunder_sub2(self):
         IdentitySet = util.IdentitySet
-        o1, o2, o3 = object(), object(), object()
+        o1, o2, o3 = self.obj_type(), self.obj_type(), self.obj_type()
         ids1 = IdentitySet([o1])
         ids2 = IdentitySet([o1, o2, o3])
         eq_(ids2 - ids1, IdentitySet([o2, o3]))
@@ -1549,7 +1558,13 @@ class IdentitySetTest(fixtures.TestBase):
         pass  # TODO
 
     def _create_sets(self):
-        o1, o2, o3, o4, o5 = object(), object(), object(), object(), object()
+        o1, o2, o3, o4, o5 = (
+            self.obj_type(),
+            self.obj_type(),
+            self.obj_type(),
+            self.obj_type(),
+            self.obj_type(),
+        )
         super_ = util.IdentitySet([o1, o2, o3])
         sub_ = util.IdentitySet([o2])
         twin1 = util.IdentitySet([o3])
@@ -1573,7 +1588,7 @@ class IdentitySetTest(fixtures.TestBase):
     def test_basic_sanity(self):
         IdentitySet = util.IdentitySet
 
-        o1, o2, o3 = object(), object(), object()
+        o1, o2, o3 = self.obj_type(), self.obj_type(), self.obj_type()
         ids = IdentitySet([o1])
         ids.discard(o1)
         ids.discard(o1)
@@ -1638,6 +1653,10 @@ class IdentitySetTest(fixtures.TestBase):
         assert_raises(TypeError, hash, ids)
 
 
+class NoHashIdentitySetTest(IdentitySetTest):
+    obj_type = NoHash
+
+
 class OrderedIdentitySetTest(fixtures.TestBase):
     def assert_eq(self, identityset, expected_iterable):
         expected = [id(o) for o in expected_iterable]