git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-89967: make WeakKeyDictionary and WeakValueDictionary thread safe (#125325)
author Kumar Aditya <kumaraditya@python.org>
Sun, 13 Oct 2024 15:35:05 +0000 (21:05 +0530)
committer GitHub <noreply@github.com>
Sun, 13 Oct 2024 15:35:05 +0000 (21:05 +0530)
Make `WeakKeyDictionary` and `WeakValueDictionary` thread safe by copying the underlying dict before iterating over it.

Lib/_weakrefset.py
Lib/weakref.py
Misc/NEWS.d/next/Library/2024-10-11-16-19-46.gh-issue-89967.vhWUOR.rst [new file with mode: 0644]

index 2071755d71dfc8459615cd9041e388c638922618..d1c7fcaeec9821c80861f7039d4aaca5a70644e6 100644 (file)
@@ -8,31 +8,6 @@ from types import GenericAlias
 __all__ = ['WeakSet']
 
 
-class _IterationGuard:
-    # This context manager registers itself in the current iterators of the
-    # weak container, such as to delay all removals until the context manager
-    # exits.
-    # This technique should be relatively thread-safe (since sets are).
-
-    def __init__(self, weakcontainer):
-        # Don't create cycles
-        self.weakcontainer = ref(weakcontainer)
-
-    def __enter__(self):
-        w = self.weakcontainer()
-        if w is not None:
-            w._iterating.add(self)
-        return self
-
-    def __exit__(self, e, t, b):
-        w = self.weakcontainer()
-        if w is not None:
-            s = w._iterating
-            s.remove(self)
-            if not s:
-                w._commit_removals()
-
-
 class WeakSet:
     def __init__(self, data=None):
         self.data = set()
index 25b70927e29c31d5fbe48beab554270c866c5829..94e4278143c9878a82aa1c397267309ea0ea16cc 100644 (file)
@@ -19,7 +19,7 @@ from _weakref import (
      ReferenceType,
      _remove_dead_weakref)
 
-from _weakrefset import WeakSet, _IterationGuard
+from _weakrefset import WeakSet
 
 import _collections_abc  # Import after _weakref to avoid circular import.
 import sys
@@ -105,34 +105,14 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
         def remove(wr, selfref=ref(self), _atomic_removal=_remove_dead_weakref):
             self = selfref()
             if self is not None:
-                if self._iterating:
-                    self._pending_removals.append(wr.key)
-                else:
-                    # Atomic removal is necessary since this function
-                    # can be called asynchronously by the GC
-                    _atomic_removal(self.data, wr.key)
+                # Atomic removal is necessary since this function
+                # can be called asynchronously by the GC
+                _atomic_removal(self.data, wr.key)
         self._remove = remove
-        # A list of keys to be removed
-        self._pending_removals = []
-        self._iterating = set()
         self.data = {}
         self.update(other, **kw)
 
-    def _commit_removals(self, _atomic_removal=_remove_dead_weakref):
-        pop = self._pending_removals.pop
-        d = self.data
-        # We shouldn't encounter any KeyError, because this method should
-        # always be called *before* mutating the dict.
-        while True:
-            try:
-                key = pop()
-            except IndexError:
-                return
-            _atomic_removal(d, key)
-
     def __getitem__(self, key):
-        if self._pending_removals:
-            self._commit_removals()
         o = self.data[key]()
         if o is None:
             raise KeyError(key)
@@ -140,18 +120,12 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
             return o
 
     def __delitem__(self, key):
-        if self._pending_removals:
-            self._commit_removals()
         del self.data[key]
 
     def __len__(self):
-        if self._pending_removals:
-            self._commit_removals()
         return len(self.data)
 
     def __contains__(self, key):
-        if self._pending_removals:
-            self._commit_removals()
         try:
             o = self.data[key]()
         except KeyError:
@@ -162,38 +136,28 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
         return "<%s at %#x>" % (self.__class__.__name__, id(self))
 
     def __setitem__(self, key, value):
-        if self._pending_removals:
-            self._commit_removals()
         self.data[key] = KeyedRef(value, self._remove, key)
 
     def copy(self):
-        if self._pending_removals:
-            self._commit_removals()
         new = WeakValueDictionary()
-        with _IterationGuard(self):
-            for key, wr in self.data.items():
-                o = wr()
-                if o is not None:
-                    new[key] = o
+        for key, wr in self.data.copy().items():
+            o = wr()
+            if o is not None:
+                new[key] = o
         return new
 
     __copy__ = copy
 
     def __deepcopy__(self, memo):
         from copy import deepcopy
-        if self._pending_removals:
-            self._commit_removals()
         new = self.__class__()
-        with _IterationGuard(self):
-            for key, wr in self.data.items():
-                o = wr()
-                if o is not None:
-                    new[deepcopy(key, memo)] = o
+        for key, wr in self.data.copy().items():
+            o = wr()
+            if o is not None:
+                new[deepcopy(key, memo)] = o
         return new
 
     def get(self, key, default=None):
-        if self._pending_removals:
-            self._commit_removals()
         try:
             wr = self.data[key]
         except KeyError:
@@ -207,21 +171,15 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
                 return o
 
     def items(self):
-        if self._pending_removals:
-            self._commit_removals()
-        with _IterationGuard(self):
-            for k, wr in self.data.items():
-                v = wr()
-                if v is not None:
-                    yield k, v
+        for k, wr in self.data.copy().items():
+            v = wr()
+            if v is not None:
+                yield k, v
 
     def keys(self):
-        if self._pending_removals:
-            self._commit_removals()
-        with _IterationGuard(self):
-            for k, wr in self.data.items():
-                if wr() is not None:
-                    yield k
+        for k, wr in self.data.copy().items():
+            if wr() is not None:
+                yield k
 
     __iter__ = keys
 
@@ -235,23 +193,15 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
         keep the values around longer than needed.
 
         """
-        if self._pending_removals:
-            self._commit_removals()
-        with _IterationGuard(self):
-            yield from self.data.values()
+        yield from self.data.copy().values()
 
     def values(self):
-        if self._pending_removals:
-            self._commit_removals()
-        with _IterationGuard(self):
-            for wr in self.data.values():
-                obj = wr()
-                if obj is not None:
-                    yield obj
+        for wr in self.data.copy().values():
+            obj = wr()
+            if obj is not None:
+                yield obj
 
     def popitem(self):
-        if self._pending_removals:
-            self._commit_removals()
         while True:
             key, wr = self.data.popitem()
             o = wr()
@@ -259,8 +209,6 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
                 return key, o
 
     def pop(self, key, *args):
-        if self._pending_removals:
-            self._commit_removals()
         try:
             o = self.data.pop(key)()
         except KeyError:
@@ -279,16 +227,12 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
         except KeyError:
             o = None
         if o is None:
-            if self._pending_removals:
-                self._commit_removals()
             self.data[key] = KeyedRef(default, self._remove, key)
             return default
         else:
             return o
 
     def update(self, other=None, /, **kwargs):
-        if self._pending_removals:
-            self._commit_removals()
         d = self.data
         if other is not None:
             if not hasattr(other, "items"):
@@ -308,9 +252,7 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
         keep the values around longer than needed.
 
         """
-        if self._pending_removals:
-            self._commit_removals()
-        return list(self.data.values())
+        return list(self.data.copy().values())
 
     def __ior__(self, other):
         self.update(other)
@@ -369,57 +311,22 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
         def remove(k, selfref=ref(self)):
             self = selfref()
             if self is not None:
-                if self._iterating:
-                    self._pending_removals.append(k)
-                else:
-                    try:
-                        del self.data[k]
-                    except KeyError:
-                        pass
+                try:
+                    del self.data[k]
+                except KeyError:
+                    pass
         self._remove = remove
-        # A list of dead weakrefs (keys to be removed)
-        self._pending_removals = []
-        self._iterating = set()
-        self._dirty_len = False
         if dict is not None:
             self.update(dict)
 
-    def _commit_removals(self):
-        # NOTE: We don't need to call this method before mutating the dict,
-        # because a dead weakref never compares equal to a live weakref,
-        # even if they happened to refer to equal objects.
-        # However, it means keys may already have been removed.
-        pop = self._pending_removals.pop
-        d = self.data
-        while True:
-            try:
-                key = pop()
-            except IndexError:
-                return
-
-            try:
-                del d[key]
-            except KeyError:
-                pass
-
-    def _scrub_removals(self):
-        d = self.data
-        self._pending_removals = [k for k in self._pending_removals if k in d]
-        self._dirty_len = False
-
     def __delitem__(self, key):
-        self._dirty_len = True
         del self.data[ref(key)]
 
     def __getitem__(self, key):
         return self.data[ref(key)]
 
     def __len__(self):
-        if self._dirty_len and self._pending_removals:
-            # self._pending_removals may still contain keys which were
-            # explicitly removed, we have to scrub them (see issue #21173).
-            self._scrub_removals()
-        return len(self.data) - len(self._pending_removals)
+        return len(self.data)
 
     def __repr__(self):
         return "<%s at %#x>" % (self.__class__.__name__, id(self))
@@ -429,11 +336,10 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
 
     def copy(self):
         new = WeakKeyDictionary()
-        with _IterationGuard(self):
-            for key, value in self.data.items():
-                o = key()
-                if o is not None:
-                    new[o] = value
+        for key, value in self.data.copy().items():
+            o = key()
+            if o is not None:
+                new[o] = value
         return new
 
     __copy__ = copy
@@ -441,11 +347,10 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
     def __deepcopy__(self, memo):
         from copy import deepcopy
         new = self.__class__()
-        with _IterationGuard(self):
-            for key, value in self.data.items():
-                o = key()
-                if o is not None:
-                    new[o] = deepcopy(value, memo)
+        for key, value in self.data.copy().items():
+            o = key()
+            if o is not None:
+                new[o] = deepcopy(value, memo)
         return new
 
     def get(self, key, default=None):
@@ -459,26 +364,23 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
         return wr in self.data
 
     def items(self):
-        with _IterationGuard(self):
-            for wr, value in self.data.items():
-                key = wr()
-                if key is not None:
-                    yield key, value
+        for wr, value in self.data.copy().items():
+            key = wr()
+            if key is not None:
+                yield key, value
 
     def keys(self):
-        with _IterationGuard(self):
-            for wr in self.data:
-                obj = wr()
-                if obj is not None:
-                    yield obj
+        for wr in self.data.copy():
+            obj = wr()
+            if obj is not None:
+                yield obj
 
     __iter__ = keys
 
     def values(self):
-        with _IterationGuard(self):
-            for wr, value in self.data.items():
-                if wr() is not None:
-                    yield value
+        for wr, value in self.data.copy().items():
+            if wr() is not None:
+                yield value
 
     def keyrefs(self):
         """Return a list of weak references to the keys.
@@ -493,7 +395,6 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
         return list(self.data)
 
     def popitem(self):
-        self._dirty_len = True
         while True:
             key, value = self.data.popitem()
             o = key()
@@ -501,7 +402,6 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
                 return o, value
 
     def pop(self, key, *args):
-        self._dirty_len = True
         return self.data.pop(ref(key), *args)
 
     def setdefault(self, key, default=None):
diff --git a/Misc/NEWS.d/next/Library/2024-10-11-16-19-46.gh-issue-89967.vhWUOR.rst b/Misc/NEWS.d/next/Library/2024-10-11-16-19-46.gh-issue-89967.vhWUOR.rst
new file mode 100644 (file)
index 0000000..d086045
--- /dev/null
@@ -0,0 +1 @@
+Make :class:`~weakref.WeakKeyDictionary` and :class:`~weakref.WeakValueDictionary` safe against concurrent mutations from other threads. Patch by Kumar Aditya.