]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Guard all indexed access in WeakInstanceDict
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 17 Jul 2017 15:06:22 +0000 (11:06 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 24 Jul 2017 15:48:13 +0000 (11:48 -0400)
Added ``KeyError`` checks to all methods within
:class:`.WeakInstanceDict` where a check for ``key in dict`` is
followed by indexed access to that key, to guard against a race against
garbage collection that under load can remove the key from the dict
after the code assumes its present, leading to very infrequent
``KeyError`` raises.

Change-Id: I881cc2899f7961d29a0549f44149a2615ae7a4ea
Fixes: #4030
doc/build/changelog/unreleased_11/4030.rst [new file with mode: 0644]
lib/sqlalchemy/orm/identity.py

diff --git a/doc/build/changelog/unreleased_11/4030.rst b/doc/build/changelog/unreleased_11/4030.rst
new file mode 100644 (file)
index 0000000..5b25b0a
--- /dev/null
@@ -0,0 +1,12 @@
+.. change:: 4030
+    :tags: bug, orm
+    :versions: 1.2.0b2
+    :tickets: 4030
+
+    Added ``KeyError`` checks to all methods within
+    :class:`.WeakInstanceDict` where a check for ``key in dict`` is
+    followed by indexed access to that key, to guard against a race against
+    garbage collection that under load can remove the key from the dict
+    after the code assumes its present, leading to very infrequent
+    ``KeyError`` raises.
+
index ca87fa20594ed69aba140cd726d7aecf165255ff..8f4304ad2633f70a2d31a5639cf1bf67a7c6c153 100644 (file)
@@ -105,15 +105,26 @@ class WeakInstanceDict(IdentityMap):
             return o is not None
 
     def contains_state(self, state):
-        return state.key in self._dict and self._dict[state.key] is state
+        if state.key in self._dict:
+            try:
+                return self._dict[state.key] is state
+            except KeyError:
+                return False
+        else:
+            return False
 
     def replace(self, state):
         if state.key in self._dict:
-            existing = self._dict[state.key]
-            if existing is not state:
-                self._manage_removed_state(existing)
+            try:
+                existing = self._dict[state.key]
+            except KeyError:
+                # catch gc removed the key after we just checked for it
+                pass
             else:
-                return
+                if existing is not state:
+                    self._manage_removed_state(existing)
+                else:
+                    return
 
         self._dict[state.key] = state
         self._manage_incoming_state(state)
@@ -124,6 +135,10 @@ class WeakInstanceDict(IdentityMap):
         if key in self._dict:
             try:
                 existing_state = self._dict[key]
+            except KeyError:
+                # catch gc removed the key after we just checked for it
+                pass
+            else:
                 if existing_state is not state:
                     o = existing_state.obj()
                     if o is not None:
@@ -134,8 +149,6 @@ class WeakInstanceDict(IdentityMap):
                                 orm_util.state_str(state), state.key))
                 else:
                     return False
-            except KeyError:
-                pass
         self._dict[key] = state
         self._manage_incoming_state(state)
         return True
@@ -148,11 +161,16 @@ class WeakInstanceDict(IdentityMap):
     def get(self, key, default=None):
         if key not in self._dict:
             return default
-        state = self._dict[key]
-        o = state.obj()
-        if o is None:
+        try:
+            state = self._dict[key]
+        except KeyError:
+            # catch gc removed the key after we just checked for it
             return default
-        return o
+        else:
+            o = state.obj()
+            if o is None:
+                return default
+            return o
 
     def items(self):
         values = self.all_states()
@@ -201,10 +219,15 @@ class WeakInstanceDict(IdentityMap):
 
     def safe_discard(self, state):
         if state.key in self._dict:
-            st = self._dict[state.key]
-            if st is state:
-                self._dict.pop(state.key, None)
-                self._manage_removed_state(state)
+            try:
+                st = self._dict[state.key]
+            except KeyError:
+                # catch gc removed the key after we just checked for it
+                pass
+            else:
+                if st is state:
+                    self._dict.pop(state.key, None)
+                    self._manage_removed_state(state)
 
     def prune(self):
         return 0